lmstudio.rs

   1use anyhow::{Result, anyhow};
   2use collections::HashMap;
   3use credentials_provider::CredentialsProvider;
   4use fs::Fs;
   5use futures::Stream;
   6use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
   7use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Subscription, Task};
   8use http_client::HttpClient;
   9use language_model::{
  10    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  11    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolResultContent,
  12    LanguageModelToolUse, MessageContent, StopReason, TokenUsage, env_var,
  13};
  14use language_model::{
  15    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  16    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
  17};
  18use lmstudio::{LMSTUDIO_API_URL, ModelType, get_models};
  19
  20pub use settings::LmStudioAvailableModel as AvailableModel;
  21use settings::{Settings, SettingsStore, update_settings_file};
  22use std::pin::Pin;
  23use std::sync::LazyLock;
  24use std::{collections::BTreeMap, sync::Arc};
  25use ui::{
  26    ButtonLike, ConfiguredApiCard, ElevationIndex, List, ListBulletItem, Tooltip, prelude::*,
  27};
  28use ui_input::InputField;
  29
  30use crate::AllLanguageModelSettings;
  31use language_model::util::parse_tool_arguments;
  32
// External LM Studio links surfaced in the configuration UI.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

// Stable machine identifier and human-readable name for this provider.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");

// Environment variable consulted by `ApiKeyState` as an alternate source for
// the API key (see `ApiKeyState::new` below).
const API_KEY_ENV_VAR_NAME: &str = "LMSTUDIO_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
  42
/// User-configurable settings for the LM Studio provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the LM Studio server's HTTP API.
    pub api_url: String,
    /// Models declared directly in settings; these take precedence over
    /// server-reported models with the same name.
    pub available_models: Vec<AvailableModel>,
}
  48
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared mutable state (API key, fetched model list, subscriptions).
    state: Entity<State>,
}
  53
/// Mutable provider state observed by the configuration UI.
pub struct State {
    /// API key storage; keys are scoped to the configured API URL.
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
    http_client: Arc<dyn HttpClient>,
    /// Models most recently fetched from the LM Studio server.
    available_models: Vec<lmstudio::Model>,
    /// Handle to the in-flight model fetch, if any; replaced on each restart.
    fetch_model_task: Option<Task<Result<()>>>,
    /// Keeps the settings-change observer alive for this state's lifetime.
    _subscription: Subscription,
}
  62
  63impl State {
  64    fn is_authenticated(&self) -> bool {
  65        !self.available_models.is_empty()
  66    }
  67
  68    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  69        let credentials_provider = self.credentials_provider.clone();
  70        let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
  71        let task = self.api_key_state.store(
  72            api_url,
  73            api_key,
  74            |this| &mut this.api_key_state,
  75            credentials_provider,
  76            cx,
  77        );
  78        self.restart_fetch_models_task(cx);
  79        task
  80    }
  81
  82    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
  83        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
  84        let http_client = self.http_client.clone();
  85        let api_url = settings.api_url.clone();
  86        let api_key = self.api_key_state.key(&api_url);
  87
  88        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
  89        cx.spawn(async move |this, cx| {
  90            let models =
  91                get_models(http_client.as_ref(), &api_url, api_key.as_deref(), None).await?;
  92
  93            let mut models: Vec<lmstudio::Model> = models
  94                .into_iter()
  95                .filter(|model| model.r#type != ModelType::Embeddings)
  96                .map(|model| {
  97                    lmstudio::Model::new(
  98                        &model.id,
  99                        None,
 100                        model
 101                            .loaded_context_length
 102                            .or_else(|| model.max_context_length),
 103                        model.capabilities.supports_tool_calls(),
 104                        model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
 105                    )
 106                })
 107                .collect();
 108
 109            models.sort_by(|a, b| a.name.cmp(&b.name));
 110
 111            this.update(cx, |this, cx| {
 112                this.available_models = models;
 113                cx.notify();
 114            })
 115        })
 116    }
 117
 118    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
 119        let task = self.fetch_models(cx);
 120        self.fetch_model_task.replace(task);
 121    }
 122
 123    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 124        let credentials_provider = self.credentials_provider.clone();
 125        let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
 126        let _task = self.api_key_state.load_if_needed(
 127            api_url,
 128            |this| &mut this.api_key_state,
 129            credentials_provider,
 130            cx,
 131        );
 132
 133        if self.is_authenticated() {
 134            return Task::ready(Ok(()));
 135        }
 136
 137        let fetch_models_task = self.fetch_models(cx);
 138        cx.spawn(async move |_this, _cx| {
 139            match fetch_models_task.await {
 140                Ok(()) => Ok(()),
 141                Err(err) => {
 142                    // If any cause in the error chain is an std::io::Error with
 143                    // ErrorKind::ConnectionRefused, treat this as "credentials not found"
 144                    // (i.e. LM Studio not running).
 145                    let mut connection_refused = false;
 146                    for cause in err.chain() {
 147                        if let Some(io_err) = cause.downcast_ref::<std::io::Error>() {
 148                            if io_err.kind() == std::io::ErrorKind::ConnectionRefused {
 149                                connection_refused = true;
 150                                break;
 151                            }
 152                        }
 153                    }
 154                    if connection_refused {
 155                        Err(AuthenticateError::ConnectionRefused)
 156                    } else {
 157                        Err(AuthenticateError::Other(err))
 158                    }
 159                }
 160            }
 161        })
 162    }
 163}
 164
impl LmStudioLanguageModelProvider {
    /// Creates the provider and immediately starts fetching the model list.
    ///
    /// The constructed `State` observes the global `SettingsStore`: when the
    /// LM Studio settings actually change, it re-resolves the API key for the
    /// (possibly new) URL and restarts the model fetch.
    pub fn new(
        http_client: Arc<dyn HttpClient>,
        credentials_provider: Arc<dyn CredentialsProvider>,
        cx: &mut App,
    ) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Snapshot of the last-seen settings, used to detect real
                    // changes (the global observer fires for any settings edit).
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings =
                            AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                        if settings != new_settings {
                            let credentials_provider = this.credentials_provider.clone();
                            let api_url = Self::api_url(cx).into();
                            // API keys are scoped to the URL, so a URL change
                            // requires re-resolving the stored key.
                            this.api_key_state.handle_url_change(
                                api_url,
                                |this| &mut this.api_key_state,
                                credentials_provider,
                                cx,
                            );
                            settings = new_settings;
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    api_key_state: ApiKeyState::new(
                        Self::api_url(cx).into(),
                        (*API_KEY_ENV_VAR).clone(),
                    ),
                    credentials_provider,
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }

    /// The LM Studio API URL currently configured in global settings.
    fn api_url(cx: &App) -> String {
        AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .api_url
            .clone()
    }

    /// Whether the configured URL differs from the LM Studio default.
    fn has_custom_url(cx: &App) -> bool {
        Self::api_url(cx) != LMSTUDIO_API_URL
    }
}
 224
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    // Exposes the internal state entity so callers can observe changes
    // (e.g. the model list updating after a fetch).
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 232
 233impl LanguageModelProvider for LmStudioLanguageModelProvider {
 234    fn id(&self) -> LanguageModelProviderId {
 235        PROVIDER_ID
 236    }
 237
 238    fn name(&self) -> LanguageModelProviderName {
 239        PROVIDER_NAME
 240    }
 241
 242    fn icon(&self) -> IconOrSvg {
 243        IconOrSvg::Icon(IconName::AiLmStudio)
 244    }
 245
 246    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
 247        // We shouldn't try to select default model, because it might lead to a load call for an unloaded model.
 248        // In a constrained environment where user might not have enough resources it'll be a bad UX to select something
 249        // to load by default.
 250        None
 251    }
 252
 253    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
 254        // See explanation for default_model.
 255        None
 256    }
 257
 258    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 259        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
 260
 261        // Add models from the LM Studio API
 262        for model in self.state.read(cx).available_models.iter() {
 263            models.insert(model.name.clone(), model.clone());
 264        }
 265
 266        // Override with available models from settings
 267        for model in AllLanguageModelSettings::get_global(cx)
 268            .lmstudio
 269            .available_models
 270            .iter()
 271        {
 272            models.insert(
 273                model.name.clone(),
 274                lmstudio::Model {
 275                    name: model.name.clone(),
 276                    display_name: model.display_name.clone(),
 277                    max_tokens: model.max_tokens,
 278                    supports_tool_calls: model.supports_tool_calls,
 279                    supports_images: model.supports_images,
 280                },
 281            );
 282        }
 283
 284        models
 285            .into_values()
 286            .map(|model| {
 287                Arc::new(LmStudioLanguageModel {
 288                    id: LanguageModelId::from(model.name.clone()),
 289                    model,
 290                    http_client: self.http_client.clone(),
 291                    request_limiter: RateLimiter::new(4),
 292                    state: self.state.clone(),
 293                }) as Arc<dyn LanguageModel>
 294            })
 295            .collect()
 296    }
 297
 298    fn is_authenticated(&self, cx: &App) -> bool {
 299        self.state.read(cx).is_authenticated()
 300    }
 301
 302    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 303        self.state.update(cx, |state, cx| state.authenticate(cx))
 304    }
 305
 306    fn configuration_view(
 307        &self,
 308        _target_agent: language_model::ConfigurationViewTargetAgent,
 309        _window: &mut Window,
 310        cx: &mut App,
 311    ) -> AnyView {
 312        cx.new(|cx| ConfigurationView::new(self.state.clone(), _window, cx))
 313            .into()
 314    }
 315
 316    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 317        self.state
 318            .update(cx, |state, cx| state.set_api_key(None, cx))
 319    }
 320}
 321
/// A single chat-capable model served by an LM Studio instance.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    /// Wraps outgoing completion requests (constructed with capacity 4 in
    /// `provided_models`) — presumably limiting concurrent requests; see
    /// `RateLimiter` for exact semantics.
    request_limiter: RateLimiter,
    state: Entity<State>,
}
 329
 330impl LmStudioLanguageModel {
 331    fn to_lmstudio_request(
 332        &self,
 333        request: LanguageModelRequest,
 334    ) -> lmstudio::ChatCompletionRequest {
 335        let mut messages = Vec::new();
 336
 337        for message in request.messages {
 338            for content in message.content {
 339                match content {
 340                    MessageContent::Text(text) => add_message_content_part(
 341                        lmstudio::MessagePart::Text { text },
 342                        message.role,
 343                        &mut messages,
 344                    ),
 345                    MessageContent::Thinking { .. } => {}
 346                    MessageContent::RedactedThinking(_) => {}
 347                    MessageContent::Image(image) => {
 348                        add_message_content_part(
 349                            lmstudio::MessagePart::Image {
 350                                image_url: lmstudio::ImageUrl {
 351                                    url: image.to_base64_url(),
 352                                    detail: None,
 353                                },
 354                            },
 355                            message.role,
 356                            &mut messages,
 357                        );
 358                    }
 359                    MessageContent::ToolUse(tool_use) => {
 360                        let tool_call = lmstudio::ToolCall {
 361                            id: tool_use.id.to_string(),
 362                            content: lmstudio::ToolCallContent::Function {
 363                                function: lmstudio::FunctionContent {
 364                                    name: tool_use.name.to_string(),
 365                                    arguments: serde_json::to_string(&tool_use.input)
 366                                        .unwrap_or_default(),
 367                                },
 368                            },
 369                        };
 370
 371                        if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
 372                            messages.last_mut()
 373                        {
 374                            tool_calls.push(tool_call);
 375                        } else {
 376                            messages.push(lmstudio::ChatMessage::Assistant {
 377                                content: None,
 378                                tool_calls: vec![tool_call],
 379                            });
 380                        }
 381                    }
 382                    MessageContent::ToolResult(tool_result) => {
 383                        let content = match &tool_result.content {
 384                            LanguageModelToolResultContent::Text(text) => {
 385                                vec![lmstudio::MessagePart::Text {
 386                                    text: text.to_string(),
 387                                }]
 388                            }
 389                            LanguageModelToolResultContent::Image(image) => {
 390                                vec![lmstudio::MessagePart::Image {
 391                                    image_url: lmstudio::ImageUrl {
 392                                        url: image.to_base64_url(),
 393                                        detail: None,
 394                                    },
 395                                }]
 396                            }
 397                        };
 398
 399                        messages.push(lmstudio::ChatMessage::Tool {
 400                            content: content.into(),
 401                            tool_call_id: tool_result.tool_use_id.to_string(),
 402                        });
 403                    }
 404                }
 405            }
 406        }
 407
 408        lmstudio::ChatCompletionRequest {
 409            model: self.model.name.clone(),
 410            messages,
 411            stream: true,
 412            max_tokens: Some(-1),
 413            stop: Some(request.stop),
 414            // In LM Studio you can configure specific settings you'd like to use for your model.
 415            // For example Qwen3 is recommended to be used with 0.7 temperature.
 416            // It would be a bad UX to silently override these settings from Zed, so we pass no temperature as a default.
 417            temperature: request.temperature.or(None),
 418            tools: request
 419                .tools
 420                .into_iter()
 421                .map(|tool| lmstudio::ToolDefinition::Function {
 422                    function: lmstudio::FunctionDefinition {
 423                        name: tool.name,
 424                        description: Some(tool.description),
 425                        parameters: Some(tool.input_schema),
 426                    },
 427                })
 428                .collect(),
 429            tool_choice: request.tool_choice.map(|choice| match choice {
 430                LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
 431                LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
 432                LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
 433            }),
 434        }
 435    }
 436
 437    fn stream_completion(
 438        &self,
 439        request: lmstudio::ChatCompletionRequest,
 440        cx: &AsyncApp,
 441    ) -> BoxFuture<
 442        'static,
 443        Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
 444    > {
 445        let http_client = self.http_client.clone();
 446        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
 447            let api_url = LmStudioLanguageModelProvider::api_url(cx);
 448            (state.api_key_state.key(&api_url), api_url)
 449        });
 450
 451        let future = self.request_limiter.stream(async move {
 452            let stream = lmstudio::stream_chat_completion(
 453                http_client.as_ref(),
 454                &api_url,
 455                api_key.as_deref(),
 456                request,
 457            )
 458            .await?;
 459            Ok(stream)
 460        });
 461
 462        async move { Ok(future.await?.boxed()) }.boxed()
 463    }
 464}
 465
 466impl LanguageModel for LmStudioLanguageModel {
 467    fn id(&self) -> LanguageModelId {
 468        self.id.clone()
 469    }
 470
 471    fn name(&self) -> LanguageModelName {
 472        LanguageModelName::from(self.model.display_name().to_string())
 473    }
 474
 475    fn provider_id(&self) -> LanguageModelProviderId {
 476        PROVIDER_ID
 477    }
 478
 479    fn provider_name(&self) -> LanguageModelProviderName {
 480        PROVIDER_NAME
 481    }
 482
 483    fn supports_tools(&self) -> bool {
 484        self.model.supports_tool_calls()
 485    }
 486
 487    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 488        self.supports_tools()
 489            && match choice {
 490                LanguageModelToolChoice::Auto => true,
 491                LanguageModelToolChoice::Any => true,
 492                LanguageModelToolChoice::None => true,
 493            }
 494    }
 495
 496    fn supports_images(&self) -> bool {
 497        self.model.supports_images
 498    }
 499
 500    fn telemetry_id(&self) -> String {
 501        format!("lmstudio/{}", self.model.id())
 502    }
 503
 504    fn max_token_count(&self) -> u64 {
 505        self.model.max_token_count()
 506    }
 507
 508    fn count_tokens(
 509        &self,
 510        request: LanguageModelRequest,
 511        _cx: &App,
 512    ) -> BoxFuture<'static, Result<u64>> {
 513        // Endpoint for this is coming soon. In the meantime, hacky estimation
 514        let token_count = request
 515            .messages
 516            .iter()
 517            .map(|msg| msg.string_contents().split_whitespace().count())
 518            .sum::<usize>();
 519
 520        let estimated_tokens = (token_count as f64 * 0.75) as u64;
 521        async move { Ok(estimated_tokens) }.boxed()
 522    }
 523
 524    fn stream_completion(
 525        &self,
 526        request: LanguageModelRequest,
 527        cx: &AsyncApp,
 528    ) -> BoxFuture<
 529        'static,
 530        Result<
 531            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 532            LanguageModelCompletionError,
 533        >,
 534    > {
 535        let request = self.to_lmstudio_request(request);
 536        let completions = self.stream_completion(request, cx);
 537        async move {
 538            let mapper = LmStudioEventMapper::new();
 539            Ok(mapper.map_stream(completions.await?).boxed())
 540        }
 541        .boxed()
 542    }
 543}
 544
/// Translates LM Studio response-stream events into completion events,
/// accumulating partial tool calls (keyed by their stream index) across chunks.
struct LmStudioEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
 548
impl LmStudioEventMapper {
    fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    /// Maps the raw event stream into completion events; each incoming event
    /// may expand to zero or more output events.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

    /// Converts one response chunk into completion events, updating the
    /// partial tool-call accumulator as a side effect.
    pub fn map_event(
        &mut self,
        event: lmstudio::ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Only the first choice is considered; a chunk without choices is a
        // protocol error.
        let Some(choice) = event.choices.into_iter().next() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content {
            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
        }

        if let Some(reasoning_content) = choice.delta.reasoning_content {
            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                text: reasoning_content,
                signature: None,
            }));
        }

        // Tool-call deltas arrive in fragments; merge each fragment into the
        // accumulator entry for its index.
        if let Some(tool_calls) = choice.delta.tool_calls {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function {
                    if let Some(name) = function.name {
                        // At the time of writing this code LM Studio (0.3.15) is incompatible with the OpenAI API:
                        // 1. It sends function name in the first chunk
                        // 2. It sends empty string in the function name field in all subsequent chunks for arguments
                        // According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
                        // function name field should be sent only inside the first chunk.
                        if !name.is_empty() {
                            entry.name = name;
                        }
                    }

                    if let Some(arguments) = function.arguments {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                // Flush the accumulated tool calls, parsing the concatenated
                // argument JSON for each.
                // NOTE(review): HashMap::drain yields entries in arbitrary
                // order — confirm downstream consumers don't rely on ordering.
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match parse_tool_arguments(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.into(),
                                name: tool_call.name.into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments,
                                thought_signature: None,
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            // Unknown finish reasons are logged but still end the turn.
            Some(stop_reason) => {
                log::error!("Unexpected LMStudio stop_reason: {stop_reason:?}",);
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}
 665
/// A tool call accumulated across streaming chunks: the id and name arrive in
/// early chunks, while `arguments` is JSON concatenated from fragments.
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}
 672
 673fn add_message_content_part(
 674    new_part: lmstudio::MessagePart,
 675    role: Role,
 676    messages: &mut Vec<lmstudio::ChatMessage>,
 677) {
 678    match (role, messages.last_mut()) {
 679        (Role::User, Some(lmstudio::ChatMessage::User { content }))
 680        | (
 681            Role::Assistant,
 682            Some(lmstudio::ChatMessage::Assistant {
 683                content: Some(content),
 684                ..
 685            }),
 686        )
 687        | (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
 688            content.push_part(new_part);
 689        }
 690        _ => {
 691            messages.push(match role {
 692                Role::User => lmstudio::ChatMessage::User {
 693                    content: lmstudio::MessageContent::from(vec![new_part]),
 694                },
 695                Role::Assistant => lmstudio::ChatMessage::Assistant {
 696                    content: Some(lmstudio::MessageContent::from(vec![new_part])),
 697                    tool_calls: Vec::new(),
 698                },
 699                Role::System => lmstudio::ChatMessage::System {
 700                    content: lmstudio::MessageContent::from(vec![new_part]),
 701                },
 702            });
 703        }
 704    }
 705}
 706
/// Settings UI for the LM Studio provider: API-key and API-URL inputs backed
/// by the shared provider state.
struct ConfigurationView {
    state: Entity<State>,
    api_key_editor: Entity<InputField>,
    api_url_editor: Entity<InputField>,
}
 712
 713impl ConfigurationView {
 714    pub fn new(state: Entity<State>, _window: &mut Window, cx: &mut Context<Self>) -> Self {
 715        let api_key_editor = cx.new(|cx| InputField::new(_window, cx, "sk-...").label("API key"));
 716
 717        let api_url_editor = cx.new(|cx| {
 718            let input = InputField::new(_window, cx, LMSTUDIO_API_URL).label("API URL");
 719            input.set_text(&LmStudioLanguageModelProvider::api_url(cx), _window, cx);
 720            input
 721        });
 722
 723        cx.observe(&state, |_, _, cx| {
 724            cx.notify();
 725        })
 726        .detach();
 727
 728        Self {
 729            state,
 730            api_key_editor,
 731            api_url_editor,
 732        }
 733    }
 734
 735    fn retry_connection(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
 736        let has_api_url = LmStudioLanguageModelProvider::has_custom_url(cx);
 737        let has_api_key = self
 738            .state
 739            .read_with(cx, |state, _| state.api_key_state.has_key());
 740        if !has_api_url {
 741            self.save_api_url(cx);
 742        }
 743        if !has_api_key {
 744            self.save_api_key(&Default::default(), _window, cx);
 745        }
 746
 747        self.state.update(cx, |state, cx| {
 748            state.restart_fetch_models_task(cx);
 749        });
 750    }
 751
 752    fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
 753        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 754        if api_key.is_empty() {
 755            return;
 756        }
 757
 758        self.api_key_editor
 759            .update(cx, |input, cx| input.set_text("", _window, cx));
 760
 761        let state = self.state.clone();
 762        cx.spawn_in(_window, async move |_, cx| {
 763            state
 764                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 765                .await
 766        })
 767        .detach_and_log_err(cx);
 768    }
 769
 770    fn reset_api_key(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
 771        self.api_key_editor
 772            .update(cx, |input, cx| input.set_text("", _window, cx));
 773
 774        let state = self.state.clone();
 775        cx.spawn_in(_window, async move |_, cx| {
 776            state
 777                .update(cx, |state, cx| state.set_api_key(None, cx))
 778                .await
 779        })
 780        .detach_and_log_err(cx);
 781
 782        cx.notify();
 783    }
 784
 785    fn save_api_url(&self, cx: &mut Context<Self>) {
 786        let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string();
 787        let current_url = LmStudioLanguageModelProvider::api_url(cx);
 788        if !api_url.is_empty() && &api_url != &current_url {
 789            self.state
 790                .update(cx, |state, cx| state.set_api_key(None, cx))
 791                .detach_and_log_err(cx);
 792
 793            let fs = <dyn Fs>::global(cx);
 794            update_settings_file(fs, cx, move |settings, _| {
 795                settings
 796                    .language_models
 797                    .get_or_insert_default()
 798                    .lmstudio
 799                    .get_or_insert_default()
 800                    .api_url = Some(api_url);
 801            });
 802        }
 803    }
 804
 805    fn reset_api_url(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
 806        self.api_url_editor
 807            .update(cx, |input, cx| input.set_text("", _window, cx));
 808
 809        // Clear API key when URL changes since keys are URL-specific
 810        self.state
 811            .update(cx, |state, cx| state.set_api_key(None, cx))
 812            .detach_and_log_err(cx);
 813
 814        let fs = <dyn Fs>::global(cx);
 815        update_settings_file(fs, cx, |settings, _cx| {
 816            if let Some(settings) = settings
 817                .language_models
 818                .as_mut()
 819                .and_then(|models| models.lmstudio.as_mut())
 820            {
 821                settings.api_url = Some(LMSTUDIO_API_URL.into());
 822            }
 823        });
 824        cx.notify();
 825    }
 826
 827    fn render_api_url_editor(&self, cx: &Context<Self>) -> impl IntoElement {
 828        let api_url = LmStudioLanguageModelProvider::api_url(cx);
 829        let custom_api_url_set = api_url != LMSTUDIO_API_URL;
 830
 831        if custom_api_url_set {
 832            h_flex()
 833                .p_3()
 834                .justify_between()
 835                .rounded_md()
 836                .border_1()
 837                .border_color(cx.theme().colors().border)
 838                .bg(cx.theme().colors().elevated_surface_background)
 839                .child(
 840                    h_flex()
 841                        .gap_2()
 842                        .child(Icon::new(IconName::Check).color(Color::Success))
 843                        .child(v_flex().gap_1().child(Label::new(api_url))),
 844                )
 845                .child(
 846                    Button::new("reset-api-url", "Reset API URL")
 847                        .label_size(LabelSize::Small)
 848                        .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
 849                        .layer(ElevationIndex::ModalSurface)
 850                        .on_click(
 851                            cx.listener(|this, _, _window, cx| this.reset_api_url(_window, cx)),
 852                        ),
 853                )
 854                .into_any_element()
 855        } else {
 856            v_flex()
 857                .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
 858                    this.save_api_url(cx);
 859                    cx.notify();
 860                }))
 861                .gap_2()
 862                .child(self.api_url_editor.clone())
 863                .into_any_element()
 864        }
 865    }
 866
 867    fn render_api_key_editor(&self, cx: &Context<Self>) -> impl IntoElement {
 868        let state = self.state.read(cx);
 869        let env_var_set = state.api_key_state.is_from_env_var();
 870        let configured_card_label = if env_var_set {
 871            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
 872        } else {
 873            "API key configured".to_string()
 874        };
 875
 876        if !state.api_key_state.has_key() {
 877            v_flex()
 878                .on_action(cx.listener(Self::save_api_key))
 879                .child(self.api_key_editor.clone())
 880                .child(
 881                    Label::new(format!(
 882                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
 883                    ))
 884                    .size(LabelSize::Small)
 885                    .color(Color::Muted),
 886                )
 887                .into_any_element()
 888        } else {
 889            ConfiguredApiCard::new(configured_card_label)
 890                .disabled(env_var_set)
 891                .on_click(cx.listener(|this, _, _window, cx| this.reset_api_key(_window, cx)))
 892                .when(env_var_set, |this| {
 893                    this.tooltip_label(format!(
 894                        "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
 895                    ))
 896                })
 897                .into_any_element()
 898        }
 899    }
 900}
 901
impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Connection status drives which action buttons appear below:
        // site link + refresh when connected, download + connect otherwise.
        let is_authenticated = self.state.read(cx).is_authenticated();

        v_flex()
            .gap_2()
            // Introductory copy: what LM Studio is for and how to get a first model.
            .child(
                v_flex()
                    .gap_1()
                    .child(Label::new("Run local LLMs like Llama, Phi, and Qwen."))
                    .child(
                        List::new()
                            .child(ListBulletItem::new(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(
                                ListBulletItem::new("")
                                    .child(Label::new("To get your first model, try running"))
                                    .child(Label::new("lms get qwen2.5-coder-7b").inline_code(cx)),
                            ),
                    )
                    .child(Label::new(
                        "Alternatively, you can connect to an LM Studio server by specifying its \
                        URL and API key (may not be required):",
                    )),
            )
            // Server URL and API key configuration widgets.
            .child(self.render_api_url_editor(cx))
            .child(self.render_api_key_editor(cx))
            .child(
                h_flex()
                    .w_full()
                    .justify_between()
                    .gap_2()
                    .child(
                        h_flex()
                            .w_full()
                            .gap_2()
                            // Left side: link to the LM Studio site when connected,
                            // otherwise a download link for first-time setup.
                            .map(|this| {
                                if is_authenticated {
                                    this.child(
                                        Button::new("lmstudio-site", "LM Studio")
                                            .style(ButtonStyle::Subtle)
                                            .end_icon(
                                                Icon::new(IconName::ArrowUpRight)
                                                    .size(IconSize::Small)
                                                    .color(Color::Muted),
                                            )
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_SITE)
                                            })
                                            .into_any_element(),
                                    )
                                } else {
                                    this.child(
                                        Button::new(
                                            "download_lmstudio_button",
                                            "Download LM Studio",
                                        )
                                        .style(ButtonStyle::Subtle)
                                        .end_icon(
                                            Icon::new(IconName::ArrowUpRight)
                                                .size(IconSize::Small)
                                                .color(Color::Muted),
                                        )
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                        })
                                        .into_any_element(),
                                    )
                                }
                            })
                            // Always-available link to the model catalog.
                            .child(
                                Button::new("view-models", "Model Catalog")
                                    .style(ButtonStyle::Subtle)
                                    .end_icon(
                                        Icon::new(IconName::ArrowUpRight)
                                            .size(IconSize::Small)
                                            .color(Color::Muted),
                                    )
                                    .on_click(move |_, _window, cx| {
                                        cx.open_url(LMSTUDIO_CATALOG_URL)
                                    }),
                            ),
                    )
                    // Right side: "Connected" status with a model-refresh button,
                    // or a "Connect" button that retries the connection.
                    .map(|this| {
                        if is_authenticated {
                            this.child(
                                ButtonLike::new("connected")
                                    .disabled(true)
                                    .cursor_style(CursorStyle::Arrow)
                                    .child(
                                        h_flex()
                                            .gap_2()
                                            .child(Icon::new(IconName::Check).color(Color::Success))
                                            .child(Label::new("Connected"))
                                            .into_any_element(),
                                    )
                                    .child(
                                        IconButton::new("refresh-models", IconName::RotateCcw)
                                            .tooltip(Tooltip::text("Refresh Models"))
                                            .on_click(cx.listener(|this, _, _window, cx| {
                                                // Drop the cached model list before
                                                // re-fetching so stale entries disappear.
                                                this.state.update(cx, |state, _| {
                                                    state.available_models.clear();
                                                });
                                                this.retry_connection(_window, cx);
                                            })),
                                    ),
                            )
                        } else {
                            this.child(
                                Button::new("retry_lmstudio_models", "Connect")
                                    .start_icon(
                                        Icon::new(IconName::PlayFilled).size(IconSize::XSmall),
                                    )
                                    .on_click(cx.listener(move |this, _, _window, cx| {
                                        this.retry_connection(_window, cx)
                                    })),
                            )
                        }
                    }),
            )
    }
}