local.rs

  1use anyhow::{Result, anyhow};
  2use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
  3use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
  4use http_client::HttpClient;
  5use language_model::{
  6    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  7    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  8    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  9    LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
 10};
 11use mistralrs::{
 12    IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
 13    TextModelBuilder,
 14};
 15use serde::{Deserialize, Serialize};
 16use std::sync::Arc;
 17use ui::{ButtonLike, IconName, Indicator, prelude::*};
 18
// Stable identifier and human-readable name for this provider, used to
// register it with the language-model registry and in the UI.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
// Hugging Face model id loaded when no other model is configured.
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
 22
/// User-configurable settings for the local provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
    // Additional models the user has made available beyond the default.
    pub available_models: Vec<AvailableModel>,
}
 27
/// A user-declared model entry (serialized to/from settings).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AvailableModel {
    // Model identifier, presumably a Hugging Face repo id like
    // `DEFAULT_MODEL` — confirm against settings documentation.
    pub name: String,
    // Optional display name shown in the UI instead of `name`.
    pub display_name: Option<String>,
    // Maximum context window, in tokens.
    pub max_tokens: u64,
}
 34
/// Language-model provider that runs a model locally via mistral.rs.
pub struct LocalLanguageModelProvider {
    // Shared, observable load state for the underlying model.
    state: Entity<State>,
}
 38
/// Observable state tracking whether the local model has been loaded.
pub struct State {
    // The loaded model; `None` until `authenticate` completes successfully.
    model: Option<Arc<MistralModel>>,
    // Current lifecycle phase of the model load.
    status: ModelStatus,
}
 43
/// Lifecycle of the local model load, surfaced in the configuration UI.
#[derive(Clone, Debug, PartialEq)]
enum ModelStatus {
    NotLoaded,
    Loading,
    Loaded,
    // Holds the stringified load error for display.
    Error(String),
}
 51
 52impl State {
 53    fn new(_cx: &mut Context<Self>) -> Self {
 54        Self {
 55            model: None,
 56            status: ModelStatus::NotLoaded,
 57        }
 58    }
 59
 60    fn is_authenticated(&self) -> bool {
 61        // Local models don't require authentication
 62        true
 63    }
 64
 65    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 66        // Skip if already loaded or currently loading
 67        if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
 68            return Task::ready(Ok(()));
 69        }
 70
 71        self.status = ModelStatus::Loading;
 72        cx.notify();
 73
 74        let background_executor = cx.background_executor().clone();
 75        cx.spawn(async move |this, cx| {
 76            eprintln!("Local model: Starting to load model");
 77
 78            // Move the model loading to a background thread
 79            let model_result = background_executor
 80                .spawn(async move { load_mistral_model().await })
 81                .await;
 82
 83            match model_result {
 84                Ok(model) => {
 85                    eprintln!("Local model: Model loaded successfully");
 86                    this.update(cx, |state, cx| {
 87                        state.model = Some(model);
 88                        state.status = ModelStatus::Loaded;
 89                        cx.notify();
 90                        eprintln!("Local model: Status updated to Loaded");
 91                    })?;
 92                    Ok(())
 93                }
 94                Err(e) => {
 95                    let error_msg = e.to_string();
 96                    eprintln!("Local model: Failed to load model - {}", error_msg);
 97                    this.update(cx, |state, cx| {
 98                        state.status = ModelStatus::Error(error_msg.clone());
 99                        cx.notify();
100                        eprintln!("Local model: Status updated to Failed");
101                    })?;
102                    Err(AuthenticateError::Other(anyhow!(
103                        "Failed to load model: {}",
104                        error_msg
105                    )))
106                }
107            }
108        })
109    }
110}
111
112async fn load_mistral_model() -> Result<Arc<MistralModel>> {
113    println!("\n\n\n\nLoading mistral model...\n\n\n");
114    eprintln!("Starting to load model: {}", DEFAULT_MODEL);
115
116    // Configure the model builder to use background threads for downloads
117    eprintln!("Creating TextModelBuilder...");
118    let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
119
120    eprintln!("Building model (this should be quick for a 0.5B model)...");
121    let start_time = std::time::Instant::now();
122
123    match builder.build().await {
124        Ok(model) => {
125            let elapsed = start_time.elapsed();
126            eprintln!("Model loaded successfully in {:?}", elapsed);
127            Ok(Arc::new(model))
128        }
129        Err(e) => {
130            eprintln!("Failed to load model: {:?}", e);
131            Err(e)
132        }
133    }
134}
135
impl LocalLanguageModelProvider {
    /// Creates the provider and its observable state.
    ///
    /// The HTTP client is unused here — model downloads are handled by
    /// mistralrs internally — presumably the parameter exists to match
    /// other providers' constructor signatures; confirm at the call site.
    pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(State::new);
        Self { state }
    }
}
142
impl LanguageModelProviderState for LocalLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the load state so the UI can observe status changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
150
impl LanguageModelProvider for LocalLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::Ai
    }

    /// Currently exposes a single model (the default one).
    ///
    /// NOTE(review): a fresh `RateLimiter` is created on every call, so the
    /// concurrency cap of 4 is per returned model instance rather than
    /// shared across calls — consider storing the limiter on the provider.
    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        vec![Arc::new(LocalLanguageModel {
            state: self.state.clone(),
            request_limiter: RateLimiter::new(4),
        })]
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    // Only one model is provided, so the "fast" model is the default one.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn is_authenticated(&self, _cx: &App) -> bool {
        // Local models don't require authentication
        true
    }

    // Triggers the model load; see `State::authenticate`.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
        cx.new(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    /// "Resetting credentials" for a local model means unloading it.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| {
            state.model = None;
            state.status = ModelStatus::NotLoaded;
            cx.notify();
        });
        Task::ready(Ok(()))
    }
}
204
/// A single locally-hosted model instance.
pub struct LocalLanguageModel {
    // Shared load state; the model handle is read from here per request.
    state: Entity<State>,
    // Caps concurrent completion requests.
    request_limiter: RateLimiter,
}
209
210impl LocalLanguageModel {
211    fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
212        let mut messages = TextMessages::new();
213
214        for message in &request.messages {
215            let mut text_content = String::new();
216
217            for content in &message.content {
218                match content {
219                    MessageContent::Text(text) => {
220                        text_content.push_str(text);
221                    }
222                    MessageContent::Image { .. } => {
223                        // For now, skip image content
224                        continue;
225                    }
226                    MessageContent::ToolResult { .. } => {
227                        // Skip tool results for now
228                        continue;
229                    }
230                    MessageContent::Thinking { .. } => {
231                        // Skip thinking content
232                        continue;
233                    }
234                    MessageContent::RedactedThinking(_) => {
235                        // Skip redacted thinking
236                        continue;
237                    }
238                    MessageContent::ToolUse(_) => {
239                        // Skip tool use
240                        continue;
241                    }
242                }
243            }
244
245            if text_content.is_empty() {
246                continue;
247            }
248
249            let role = match message.role {
250                Role::User => TextMessageRole::User,
251                Role::Assistant => TextMessageRole::Assistant,
252                Role::System => TextMessageRole::System,
253            };
254
255            messages = messages.add_message(role, text_content);
256        }
257
258        messages
259    }
260}
261
262impl LanguageModel for LocalLanguageModel {
263    fn id(&self) -> LanguageModelId {
264        LanguageModelId(DEFAULT_MODEL.into())
265    }
266
267    fn name(&self) -> LanguageModelName {
268        LanguageModelName(DEFAULT_MODEL.into())
269    }
270
271    fn provider_id(&self) -> LanguageModelProviderId {
272        PROVIDER_ID
273    }
274
275    fn provider_name(&self) -> LanguageModelProviderName {
276        PROVIDER_NAME
277    }
278
279    fn telemetry_id(&self) -> String {
280        format!("local/{}", DEFAULT_MODEL)
281    }
282
283    fn supports_tools(&self) -> bool {
284        true
285    }
286
287    fn supports_images(&self) -> bool {
288        false
289    }
290
291    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
292        true
293    }
294
295    fn max_token_count(&self) -> u64 {
296        128000 // Qwen2.5 supports 128k context
297    }
298
299    fn count_tokens(
300        &self,
301        request: LanguageModelRequest,
302        _cx: &App,
303    ) -> BoxFuture<'static, Result<u64>> {
304        // Rough estimation: 1 token ≈ 4 characters
305        let mut total_chars = 0;
306        for message in request.messages {
307            for content in message.content {
308                match content {
309                    MessageContent::Text(text) => total_chars += text.len(),
310                    _ => {}
311                }
312            }
313        }
314        let tokens = (total_chars / 4) as u64;
315        futures::future::ready(Ok(tokens)).boxed()
316    }
317
318    fn stream_completion(
319        &self,
320        request: LanguageModelRequest,
321        cx: &AsyncApp,
322    ) -> BoxFuture<
323        'static,
324        Result<
325            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
326            LanguageModelCompletionError,
327        >,
328    > {
329        let messages = self.to_mistral_messages(&request);
330        let state = self.state.clone();
331        let limiter = self.request_limiter.clone();
332
333        cx.spawn(async move |cx| {
334            let result: Result<
335                BoxStream<
336                    'static,
337                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
338                >,
339                LanguageModelCompletionError,
340            > = limiter
341                .run(async move {
342                    let model = cx
343                        .read_entity(&state, |state, _| {
344                            eprintln!(
345                                "Local model: Checking if model is loaded: {:?}",
346                                state.status
347                            );
348                            state.model.clone()
349                        })
350                        .map_err(|_| {
351                            LanguageModelCompletionError::Other(anyhow!("App state dropped"))
352                        })?
353                        .ok_or_else(|| {
354                            eprintln!("Local model: Model is not loaded!");
355                            LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
356                        })?;
357
358                    let (mut tx, rx) = mpsc::channel(32);
359
360                    // Spawn a task to handle the stream
361                    let _ = smol::spawn(async move {
362                        let mut stream = match model.stream_chat_request(messages).await {
363                            Ok(stream) => stream,
364                            Err(e) => {
365                                let _ = tx
366                                    .send(Err(LanguageModelCompletionError::Other(anyhow!(
367                                        "Failed to start stream: {}",
368                                        e
369                                    ))))
370                                    .await;
371                                return;
372                            }
373                        };
374
375                        while let Some(response) = stream.next().await {
376                            let event = match response {
377                                MistralResponse::Chunk(chunk) => {
378                                    if let Some(choice) = chunk.choices.first() {
379                                        if let Some(content) = &choice.delta.content {
380                                            Some(Ok(LanguageModelCompletionEvent::Text(
381                                                content.clone(),
382                                            )))
383                                        } else if let Some(finish_reason) = &choice.finish_reason {
384                                            let stop_reason = match finish_reason.as_str() {
385                                                "stop" => StopReason::EndTurn,
386                                                "length" => StopReason::MaxTokens,
387                                                _ => StopReason::EndTurn,
388                                            };
389                                            Some(Ok(LanguageModelCompletionEvent::Stop(
390                                                stop_reason,
391                                            )))
392                                        } else {
393                                            None
394                                        }
395                                    } else {
396                                        None
397                                    }
398                                }
399                                MistralResponse::Done(_response) => {
400                                    // For now, we don't emit usage events since the format doesn't match
401                                    None
402                                }
403                                _ => None,
404                            };
405
406                            if let Some(event) = event {
407                                if tx.send(event).await.is_err() {
408                                    break;
409                                }
410                            }
411                        }
412                    })
413                    .detach();
414
415                    Ok(rx.boxed())
416                })
417                .await;
418
419            result
420        })
421        .boxed()
422    }
423}
424
/// Settings-panel view showing model status and a load/reload button.
struct ConfigurationView {
    // Observed load state; re-rendered on `cx.notify()`.
    state: Entity<State>,
}
428
429impl Render for ConfigurationView {
430    fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
431        let status = self.state.read(cx).status.clone();
432
433        div().size_full().child(
434            div()
435                .p_4()
436                .child(
437                    div()
438                        .flex()
439                        .gap_2()
440                        .items_center()
441                        .child(match &status {
442                            ModelStatus::NotLoaded => Label::new("Model not loaded"),
443                            ModelStatus::Loading => Label::new("Loading model..."),
444                            ModelStatus::Loaded => Label::new("Model loaded"),
445                            ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
446                        })
447                        .child(match &status {
448                            ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
449                            ModelStatus::Loading => Indicator::dot().color(Color::Modified),
450                            ModelStatus::Loaded => Indicator::dot().color(Color::Success),
451                            ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
452                        }),
453                )
454                .when(!matches!(status, ModelStatus::Loading), |this| {
455                    this.child(
456                        ButtonLike::new("load_model")
457                            .child(Label::new(if matches!(status, ModelStatus::Loaded) {
458                                "Reload Model"
459                            } else {
460                                "Load Model"
461                            }))
462                            .on_click(cx.listener(|this, _, _window, cx| {
463                                this.state.update(cx, |state, cx| {
464                                    state.authenticate(cx).detach();
465                                });
466                            })),
467                    )
468                }),
469        )
470    }
471}
472
473#[cfg(test)]
474mod tests;