//! Local, on-device language model provider backed by mistral.rs.

  1use anyhow::{Result, anyhow};
  2use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  3use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
  4use http_client::HttpClient;
  5use language_model::{
  6    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  7    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  8    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  9    LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
 10};
 11use mistralrs::{
 12    IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
 13    TextModelBuilder,
 14};
 15use serde::{Deserialize, Serialize};
 16use std::sync::Arc;
 17use ui::{ButtonLike, IconName, Indicator, prelude::*};
 18
// Identifiers under which this provider registers itself.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
// Model identifier handed to `TextModelBuilder::new` (presumably a Hugging Face repo id — confirm).
const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";
 22
/// User-configurable settings for the local provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
    // Extra models the user has declared; not consumed anywhere in this file — TODO confirm usage.
    pub available_models: Vec<AvailableModel>,
}
 27
/// A user-declared model entry (serialized to/from settings).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AvailableModel {
    // Model identifier, e.g. a repo id.
    pub name: String,
    // Optional human-friendly name shown in the UI.
    pub display_name: Option<String>,
    // Context window size advertised for this model.
    pub max_tokens: u64,
}
 34
/// Provider entry point; owns the shared [`State`] entity observed by the UI.
pub struct LocalLanguageModelProvider {
    state: Entity<State>,
}
 38
/// Shared provider state: the loaded model handle (if any) and its load status.
pub struct State {
    // `Some` only after a successful load; cleared by `reset_credentials`.
    model: Option<Arc<MistralModel>>,
    status: ModelStatus,
}
 43
/// Lifecycle of the local model.
#[derive(Clone, Debug, PartialEq)]
enum ModelStatus {
    NotLoaded,     // no load attempted yet, or state was reset
    Loading,       // a load task is in flight
    Loaded,        // model is resident and usable
    Error(String), // last load failed; carries the error message
}
 51
impl State {
    /// Initial state: no model resident, nothing loading.
    fn new(_cx: &mut Context<Self>) -> Self {
        Self {
            model: None,
            status: ModelStatus::NotLoaded,
        }
    }

    /// For this provider, "authenticated" simply means the model finished loading.
    fn is_authenticated(&self) -> bool {
        matches!(self.status, ModelStatus::Loaded)
    }

    /// Starts loading the local model, or short-circuits if it is already
    /// loaded or a load is in flight.
    ///
    /// NOTE(review): when status is `Loading` this returns `Ok(())` immediately
    /// even though the model is not yet usable — confirm callers treat that as
    /// "no error" rather than "ready".
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        // A load is already in progress; don't start a second one.
        if matches!(self.status, ModelStatus::Loading) {
            return Task::ready(Ok(()));
        }

        // Flip to Loading *before* spawning so the UI updates right away.
        self.status = ModelStatus::Loading;
        cx.notify();

        // Run the load asynchronously and write the outcome back into this entity.
        cx.spawn(async move |this, cx| match load_mistral_model().await {
            Ok(model) => {
                this.update(cx, |state, cx| {
                    state.model = Some(model);
                    state.status = ModelStatus::Loaded;
                    cx.notify();
                })?;
                Ok(())
            }
            Err(e) => {
                let error_msg = e.to_string();
                this.update(cx, |state, cx| {
                    state.status = ModelStatus::Error(error_msg.clone());
                    cx.notify();
                })?;
                Err(AuthenticateError::Other(anyhow!(
                    "Failed to load model: {}",
                    error_msg
                )))
            }
        })
    }
}
 99
100async fn load_mistral_model() -> Result<Arc<MistralModel>> {
101    let model = TextModelBuilder::new(DEFAULT_MODEL)
102        .with_isq(IsqType::Q4_0)
103        .with_logging()
104        .build()
105        .await?;
106
107    Ok(Arc::new(model))
108}
109
110impl LocalLanguageModelProvider {
111    pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
112        let state = cx.new(State::new);
113        Self { state }
114    }
115}
116
117impl LanguageModelProviderState for LocalLanguageModelProvider {
118    type ObservableEntity = State;
119
120    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
121        Some(self.state.clone())
122    }
123}
124
125impl LanguageModelProvider for LocalLanguageModelProvider {
126    fn id(&self) -> LanguageModelProviderId {
127        PROVIDER_ID
128    }
129
130    fn name(&self) -> LanguageModelProviderName {
131        PROVIDER_NAME
132    }
133
134    fn icon(&self) -> IconName {
135        IconName::Ai
136    }
137
138    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
139        vec![Arc::new(LocalLanguageModel {
140            state: self.state.clone(),
141            request_limiter: RateLimiter::new(4),
142        })]
143    }
144
145    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
146        self.provided_models(cx).into_iter().next()
147    }
148
149    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
150        self.default_model(cx)
151    }
152
153    fn is_authenticated(&self, cx: &App) -> bool {
154        self.state.read(cx).is_authenticated()
155    }
156
157    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
158        self.state.update(cx, |state, cx| state.authenticate(cx))
159    }
160
161    fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
162        cx.new(|_cx| ConfigurationView {
163            state: self.state.clone(),
164        })
165        .into()
166    }
167
168    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
169        self.state.update(cx, |state, cx| {
170            state.model = None;
171            state.status = ModelStatus::NotLoaded;
172            cx.notify();
173        });
174        Task::ready(Ok(()))
175    }
176}
177
/// The single model exposed by the local provider.
pub struct LocalLanguageModel {
    // Shared with the provider; holds the loaded model handle.
    state: Entity<State>,
    // Caps concurrent completion requests.
    request_limiter: RateLimiter,
}
182
183impl LocalLanguageModel {
184    fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
185        let mut messages = TextMessages::new();
186
187        for message in &request.messages {
188            let mut text_content = String::new();
189
190            for content in &message.content {
191                match content {
192                    MessageContent::Text(text) => {
193                        text_content.push_str(text);
194                    }
195                    MessageContent::Image { .. } => {
196                        // For now, skip image content
197                        continue;
198                    }
199                    MessageContent::ToolResult { .. } => {
200                        // Skip tool results for now
201                        continue;
202                    }
203                    MessageContent::Thinking { .. } => {
204                        // Skip thinking content
205                        continue;
206                    }
207                    MessageContent::RedactedThinking(_) => {
208                        // Skip redacted thinking
209                        continue;
210                    }
211                    MessageContent::ToolUse(_) => {
212                        // Skip tool use
213                        continue;
214                    }
215                }
216            }
217
218            if text_content.is_empty() {
219                continue;
220            }
221
222            let role = match message.role {
223                Role::User => TextMessageRole::User,
224                Role::Assistant => TextMessageRole::Assistant,
225                Role::System => TextMessageRole::System,
226            };
227
228            messages = messages.add_message(role, text_content);
229        }
230
231        messages
232    }
233}
234
impl LanguageModel for LocalLanguageModel {
    fn id(&self) -> LanguageModelId {
        LanguageModelId(DEFAULT_MODEL.into())
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName(DEFAULT_MODEL.into())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn telemetry_id(&self) -> String {
        format!("local/{}", DEFAULT_MODEL)
    }

    // Tools, images, and tool choice are all unsupported by this provider.
    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn max_token_count(&self) -> u64 {
        128000 // GLM-4.5-Air supports 128k context
    }

    /// Cheap token estimate: counts the characters of text content only and
    /// assumes ~4 characters per token. Non-text content contributes nothing.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // Rough estimation: 1 token ≈ 4 characters
        let mut total_chars = 0;
        for message in request.messages {
            for content in message.content {
                match content {
                    MessageContent::Text(text) => total_chars += text.len(),
                    _ => {}
                }
            }
        }
        let tokens = (total_chars / 4) as u64;
        futures::future::ready(Ok(tokens)).boxed()
    }

    /// Streams a completion from the locally loaded model.
    ///
    /// Errors if the model has not been loaded yet (i.e. `authenticate` has
    /// not completed). Events are relayed through a bounded channel from a
    /// detached background task that drives the mistral.rs stream, so the
    /// returned stream ends when the model stream ends or the receiver drops.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        // Convert the request up front so the spawned future is 'static.
        let messages = self.to_mistral_messages(&request);
        let state = self.state.clone();
        let limiter = self.request_limiter.clone();

        cx.spawn(async move |cx| {
            let result: Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            > = limiter
                .run(async move {
                    // Fetch the model handle; fail fast if the app is gone or
                    // the model was never loaded.
                    let model = cx
                        .read_entity(&state, |state, _| state.model.clone())
                        .map_err(|_| {
                            LanguageModelCompletionError::Other(anyhow!("App state dropped"))
                        })?
                        .ok_or_else(|| {
                            LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
                        })?;

                    let (mut tx, rx) = mpsc::channel(32);

                    // Spawn a task to handle the stream
                    let _ = smol::spawn(async move {
                        let mut stream = match model.stream_chat_request(messages).await {
                            Ok(stream) => stream,
                            Err(e) => {
                                // Surface the startup failure as the first (and
                                // only) stream item.
                                let _ = tx
                                    .send(Err(LanguageModelCompletionError::Other(anyhow!(
                                        "Failed to start stream: {}",
                                        e
                                    ))))
                                    .await;
                                return;
                            }
                        };

                        // Translate mistral.rs responses into completion events.
                        while let Some(response) = stream.next().await {
                            let event = match response {
                                MistralResponse::Chunk(chunk) => {
                                    // Only the first choice is inspected.
                                    // NOTE(review): if a chunk carries both
                                    // `delta.content` and `finish_reason`, the
                                    // stop reason is dropped — confirm chunks
                                    // never combine the two.
                                    if let Some(choice) = chunk.choices.first() {
                                        if let Some(content) = &choice.delta.content {
                                            Some(Ok(LanguageModelCompletionEvent::Text(
                                                content.clone(),
                                            )))
                                        } else if let Some(finish_reason) = &choice.finish_reason {
                                            let stop_reason = match finish_reason.as_str() {
                                                "stop" => StopReason::EndTurn,
                                                "length" => StopReason::MaxTokens,
                                                _ => StopReason::EndTurn,
                                            };
                                            Some(Ok(LanguageModelCompletionEvent::Stop(
                                                stop_reason,
                                            )))
                                        } else {
                                            None
                                        }
                                    } else {
                                        None
                                    }
                                }
                                MistralResponse::Done(_response) => {
                                    // For now, we don't emit usage events since the format doesn't match
                                    None
                                }
                                _ => None,
                            };

                            if let Some(event) = event {
                                // Receiver dropped — stop driving the stream.
                                if tx.send(event).await.is_err() {
                                    break;
                                }
                            }
                        }
                    })
                    .detach();

                    Ok(rx.boxed())
                })
                .await;

            result
        })
        .boxed()
    }
}
390
/// Settings-panel view showing load status and a load/reload button.
struct ConfigurationView {
    state: Entity<State>,
}
394
395impl Render for ConfigurationView {
396    fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
397        let status = self.state.read(cx).status.clone();
398
399        div().size_full().child(
400            div()
401                .p_4()
402                .child(
403                    div()
404                        .flex()
405                        .gap_2()
406                        .items_center()
407                        .child(match &status {
408                            ModelStatus::NotLoaded => Label::new("Model not loaded"),
409                            ModelStatus::Loading => Label::new("Loading model..."),
410                            ModelStatus::Loaded => Label::new("Model loaded"),
411                            ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
412                        })
413                        .child(match &status {
414                            ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
415                            ModelStatus::Loading => Indicator::dot().color(Color::Modified),
416                            ModelStatus::Loaded => Indicator::dot().color(Color::Success),
417                            ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
418                        }),
419                )
420                .when(!matches!(status, ModelStatus::Loading), |this| {
421                    this.child(
422                        ButtonLike::new("load_model")
423                            .child(Label::new(if matches!(status, ModelStatus::Loaded) {
424                                "Reload Model"
425                            } else {
426                                "Load Model"
427                            }))
428                            .on_click(cx.listener(|this, _, _window, cx| {
429                                this.state.update(cx, |state, cx| {
430                                    state.authenticate(cx).detach();
431                                });
432                            })),
433                    )
434                }),
435        )
436    }
437}
438
439#[cfg(test)]
440mod tests;