local.rs

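//! A language model provider that runs a model locally, in-process, via the
//! `mistralrs` crate instead of calling a hosted API. No credentials are
//! required; "authentication" just loads the model.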
use anyhow::{Result, anyhow};
use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
};
use mistralrs::{
    IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
    TextModelBuilder,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::{ButtonLike, IconName, Indicator, prelude::*};

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";

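/// Settings for the local provider. These mirror the settings shape used by
/// other providers (user-configurable models), though currently only
/// `DEFAULT_MODEL` is actually served.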
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
}

pub struct LocalLanguageModelProvider {
    state: Entity<State>,
}

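/// Shared state for the provider: the `mistralrs` model, once loaded, plus
/// where it is in its load lifecycle. The model starts out unloaded and is
/// populated by `State::authenticate`.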
pub struct State {
    model: Option<Arc<MistralModel>>,
    status: ModelStatus,
}

#[derive(Clone, Debug, PartialEq)]
enum ModelStatus {
    NotLoaded,
    Loading,
    Loaded,
    Error(String),
}
 51
 52impl State {
 53    fn new(_cx: &mut Context<Self>) -> Self {
 54        Self {
 55            model: None,
 56            status: ModelStatus::NotLoaded,
 57        }
 58    }
 59
 60    fn is_authenticated(&self) -> bool {
 61        // Local models don't require authentication
 62        true
 63    }
 64
 65    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 66        if self.is_authenticated() {
 67            return Task::ready(Ok(()));
 68        }
 69
 70        if matches!(self.status, ModelStatus::Loading) {
 71            return Task::ready(Ok(()));
 72        }
 73
 74        self.status = ModelStatus::Loading;
 75        cx.notify();
 76
 77        cx.spawn(async move |this, cx| match load_mistral_model().await {
 78            Ok(model) => {
 79                this.update(cx, |state, cx| {
 80                    state.model = Some(model);
 81                    state.status = ModelStatus::Loaded;
 82                    cx.notify();
 83                })?;
 84                Ok(())
 85            }
 86            Err(e) => {
 87                let error_msg = e.to_string();
 88                this.update(cx, |state, cx| {
 89                    state.status = ModelStatus::Error(error_msg.clone());
 90                    cx.notify();
 91                })?;
 92                Err(AuthenticateError::Other(anyhow!(
 93                    "Failed to load model: {}",
 94                    error_msg
 95                )))
 96            }
 97        })
 98    }
 99}
100
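/// Builds and loads the model in-process. `TextModelBuilder::new` takes a model
/// id that mistral.rs resolves (downloading and caching the weights, typically
/// from the Hugging Face Hub), and `with_isq` requests in-situ quantization of
/// the loaded weights.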
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
    println!("Loading local model: {}", DEFAULT_MODEL);
    let model = TextModelBuilder::new(DEFAULT_MODEL)
        .with_isq(IsqType::Q4_0)
        .build()
        .await?;

    Ok(Arc::new(model))
}

impl LocalLanguageModelProvider {
    pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(State::new);
        Self { state }
    }
}

impl LanguageModelProviderState for LocalLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LocalLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::Ai
    }

    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        vec![Arc::new(LocalLanguageModel {
            state: self.state.clone(),
            request_limiter: RateLimiter::new(4),
        })]
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
        cx.new(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| {
            state.model = None;
            state.status = ModelStatus::NotLoaded;
            cx.notify();
        });
        Task::ready(Ok(()))
    }
}

pub struct LocalLanguageModel {
    state: Entity<State>,
    request_limiter: RateLimiter,
}

impl LocalLanguageModel {
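    /// Flattens a `LanguageModelRequest` into plain-text `TextMessages` for
    /// mistral.rs. Only text content is forwarded; images, tool calls/results,
    /// and thinking content are dropped, matching the capabilities this model
    /// advertises (`supports_tools` and `supports_images` are both false).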
    fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
        let mut messages = TextMessages::new();

        for message in &request.messages {
            let mut text_content = String::new();

            for content in &message.content {
                match content {
                    MessageContent::Text(text) => text_content.push_str(text),
                    // Images, tool calls/results, and thinking content aren't
                    // forwarded to the local model yet.
                    MessageContent::Image { .. }
                    | MessageContent::ToolResult { .. }
                    | MessageContent::Thinking { .. }
                    | MessageContent::RedactedThinking(_)
                    | MessageContent::ToolUse(_) => continue,
                }
            }

            if text_content.is_empty() {
                continue;
            }

            let role = match message.role {
                Role::User => TextMessageRole::User,
                Role::Assistant => TextMessageRole::Assistant,
                Role::System => TextMessageRole::System,
            };

            messages = messages.add_message(role, text_content);
        }

        messages
    }
}

impl LanguageModel for LocalLanguageModel {
    fn id(&self) -> LanguageModelId {
        LanguageModelId(DEFAULT_MODEL.into())
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName(DEFAULT_MODEL.into())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn telemetry_id(&self) -> String {
        format!("local/{}", DEFAULT_MODEL)
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn max_token_count(&self) -> u64 {
        128_000 // GLM-4.5-Air supports a 128k context window.
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // Rough estimation: 1 token ≈ 4 characters of text content.
        let mut total_chars = 0;
        for message in request.messages {
            for content in message.content {
                if let MessageContent::Text(text) = content {
                    total_chars += text.len();
                }
            }
        }
        let tokens = (total_chars / 4) as u64;
        futures::future::ready(Ok(tokens)).boxed()
    }

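    /// Streams a completion from the in-process model: converts the request to
    /// text messages, grabs the loaded model from `State`, spawns a background
    /// task on smol that drives the mistral.rs response stream, and forwards
    /// each chunk through an mpsc channel returned to the caller as the event
    /// stream.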
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let messages = self.to_mistral_messages(&request);
        let state = self.state.clone();
        let limiter = self.request_limiter.clone();

        cx.spawn(async move |cx| {
            let result: Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            > = limiter
                .run(async move {
                    let model = cx
                        .read_entity(&state, |state, _| state.model.clone())
                        .map_err(|_| {
                            LanguageModelCompletionError::Other(anyhow!("App state dropped"))
                        })?
                        .ok_or_else(|| {
                            LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
                        })?;

                    let (mut tx, rx) = mpsc::channel(32);

                    // Spawn a task to handle the stream.
                    smol::spawn(async move {
                        let mut stream = match model.stream_chat_request(messages).await {
                            Ok(stream) => stream,
                            Err(e) => {
                                let _ = tx
                                    .send(Err(LanguageModelCompletionError::Other(anyhow!(
                                        "Failed to start stream: {}",
                                        e
                                    ))))
                                    .await;
                                return;
                            }
                        };

                        while let Some(response) = stream.next().await {
                            let event = match response {
                                MistralResponse::Chunk(chunk) => {
                                    if let Some(choice) = chunk.choices.first() {
                                        if let Some(content) = &choice.delta.content {
                                            Some(Ok(LanguageModelCompletionEvent::Text(
                                                content.clone(),
                                            )))
                                        } else if let Some(finish_reason) = &choice.finish_reason {
                                            let stop_reason = match finish_reason.as_str() {
                                                "stop" => StopReason::EndTurn,
                                                "length" => StopReason::MaxTokens,
                                                _ => StopReason::EndTurn,
                                            };
                                            Some(Ok(LanguageModelCompletionEvent::Stop(
                                                stop_reason,
                                            )))
                                        } else {
                                            None
                                        }
                                    } else {
                                        None
                                    }
                                }
                                MistralResponse::Done(_response) => {
                                    // For now, we don't emit usage events since the format doesn't match.
                                    None
                                }
                                _ => None,
                            };

                            if let Some(event) = event {
                                if tx.send(event).await.is_err() {
                                    break;
                                }
                            }
                        }
                    })
                    .detach();

                    Ok(rx.boxed())
                })
                .await;

            result
        })
        .boxed()
    }
}

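/// Configuration UI for the provider: shows the model's load status and a
/// button to load (or reload) it.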
struct ConfigurationView {
    state: Entity<State>,
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
        let status = self.state.read(cx).status.clone();

        div().size_full().child(
            div()
                .p_4()
                .child(
                    div()
                        .flex()
                        .gap_2()
                        .items_center()
                        .child(match &status {
                            ModelStatus::NotLoaded => Label::new("Model not loaded"),
                            ModelStatus::Loading => Label::new("Loading model..."),
                            ModelStatus::Loaded => Label::new("Model loaded"),
                            ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
                        })
                        .child(match &status {
                            ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
                            ModelStatus::Loading => Indicator::dot().color(Color::Modified),
                            ModelStatus::Loaded => Indicator::dot().color(Color::Success),
                            ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
                        }),
                )
                .when(!matches!(status, ModelStatus::Loading), |this| {
                    this.child(
                        ButtonLike::new("load_model")
                            .child(Label::new(if matches!(status, ModelStatus::Loaded) {
                                "Reload Model"
                            } else {
                                "Load Model"
                            }))
                            .on_click(cx.listener(|this, _, _window, cx| {
                                this.state.update(cx, |state, cx| {
                                    state.authenticate(cx).detach();
                                });
                            })),
                    )
                }),
        )
    }
}

#[cfg(test)]
mod tests;