1use anyhow::{Result, anyhow};
2use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
4use http_client::HttpClient;
5use language_model::{
6 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
8 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
9 LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
10};
11use mistralrs::{
12 IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
13 TextModelBuilder,
14};
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use ui::{ButtonLike, IconName, Indicator, prelude::*};
18
/// Stable identifier used to register this provider with the model registry.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
/// Hugging Face model id loaded by default (see `load_mistral_model`).
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
22
/// User-facing settings for the local provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
    // NOTE(review): not read anywhere in this file — presumably consumed by
    // settings-loading code elsewhere; confirm before relying on it.
    pub available_models: Vec<AvailableModel>,
}
27
/// A user-declared model entry, deserialized from settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AvailableModel {
    // Model identifier (e.g. a Hugging Face repo id).
    pub name: String,
    // Optional display name; falls back to `name` when absent — TODO confirm at call site.
    pub display_name: Option<String>,
    // Maximum context length advertised for this model.
    pub max_tokens: u64,
}
34
/// Provider exposing locally-run models (no network inference, no credentials).
pub struct LocalLanguageModelProvider {
    // Shared, observable load state; also handed to each `LocalLanguageModel`.
    state: Entity<State>,
}
38
/// Observable state holding the loaded model handle and its load status.
pub struct State {
    // `Some` only after a successful `authenticate` (model load).
    model: Option<Arc<MistralModel>>,
    // Drives both the configuration UI and re-entrancy checks in `authenticate`.
    status: ModelStatus,
}
43
/// Lifecycle of the local model: NotLoaded → Loading → Loaded | Error(msg).
#[derive(Clone, Debug, PartialEq)]
enum ModelStatus {
    NotLoaded,
    Loading,
    Loaded,
    // Carries the stringified load error for display in the UI.
    Error(String),
}
51
52impl State {
53 fn new(_cx: &mut Context<Self>) -> Self {
54 Self {
55 model: None,
56 status: ModelStatus::NotLoaded,
57 }
58 }
59
60 fn is_authenticated(&self) -> bool {
61 // Local models don't require authentication
62 true
63 }
64
65 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
66 // Skip if already loaded or currently loading
67 if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
68 return Task::ready(Ok(()));
69 }
70
71 self.status = ModelStatus::Loading;
72 cx.notify();
73
74 let background_executor = cx.background_executor().clone();
75 cx.spawn(async move |this, cx| {
76 eprintln!("Local model: Starting to load model");
77
78 // Move the model loading to a background thread
79 let model_result = background_executor
80 .spawn(async move { load_mistral_model().await })
81 .await;
82
83 match model_result {
84 Ok(model) => {
85 eprintln!("Local model: Model loaded successfully");
86 this.update(cx, |state, cx| {
87 state.model = Some(model);
88 state.status = ModelStatus::Loaded;
89 cx.notify();
90 eprintln!("Local model: Status updated to Loaded");
91 })?;
92 Ok(())
93 }
94 Err(e) => {
95 let error_msg = e.to_string();
96 eprintln!("Local model: Failed to load model - {}", error_msg);
97 this.update(cx, |state, cx| {
98 state.status = ModelStatus::Error(error_msg.clone());
99 cx.notify();
100 eprintln!("Local model: Status updated to Failed");
101 })?;
102 Err(AuthenticateError::Other(anyhow!(
103 "Failed to load model: {}",
104 error_msg
105 )))
106 }
107 }
108 })
109 }
110}
111
112async fn load_mistral_model() -> Result<Arc<MistralModel>> {
113 println!("\n\n\n\nLoading mistral model...\n\n\n");
114 eprintln!("Starting to load model: {}", DEFAULT_MODEL);
115
116 // Configure the model builder to use background threads for downloads
117 eprintln!("Creating TextModelBuilder...");
118 let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
119
120 eprintln!("Building model (this should be quick for a 0.5B model)...");
121 let start_time = std::time::Instant::now();
122
123 match builder.build().await {
124 Ok(model) => {
125 let elapsed = start_time.elapsed();
126 eprintln!("Model loaded successfully in {:?}", elapsed);
127 Ok(Arc::new(model))
128 }
129 Err(e) => {
130 eprintln!("Failed to load model: {:?}", e);
131 Err(e)
132 }
133 }
134}
135
136impl LocalLanguageModelProvider {
137 pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
138 let state = cx.new(State::new);
139 Self { state }
140 }
141}
142
impl LanguageModelProviderState for LocalLanguageModelProvider {
    type ObservableEntity = State;

    // Expose the state entity so observers (e.g. the configuration view) can
    // react to load-status changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
150
151impl LanguageModelProvider for LocalLanguageModelProvider {
152 fn id(&self) -> LanguageModelProviderId {
153 PROVIDER_ID
154 }
155
156 fn name(&self) -> LanguageModelProviderName {
157 PROVIDER_NAME
158 }
159
160 fn icon(&self) -> IconName {
161 IconName::Ai
162 }
163
164 fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165 vec![Arc::new(LocalLanguageModel {
166 state: self.state.clone(),
167 request_limiter: RateLimiter::new(4),
168 })]
169 }
170
171 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
172 self.provided_models(cx).into_iter().next()
173 }
174
175 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
176 self.default_model(cx)
177 }
178
179 fn is_authenticated(&self, _cx: &App) -> bool {
180 // Local models don't require authentication
181 true
182 }
183
184 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
185 self.state.update(cx, |state, cx| state.authenticate(cx))
186 }
187
188 fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
189 cx.new(|_cx| ConfigurationView {
190 state: self.state.clone(),
191 })
192 .into()
193 }
194
195 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
196 self.state.update(cx, |state, cx| {
197 state.model = None;
198 state.status = ModelStatus::NotLoaded;
199 cx.notify();
200 });
201 Task::ready(Ok(()))
202 }
203}
204
/// A handle to the single local model, implementing `LanguageModel`.
pub struct LocalLanguageModel {
    // Shared load state; the model handle is read from here per request.
    state: Entity<State>,
    // Caps concurrent completion requests (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
209
210impl LocalLanguageModel {
211 fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
212 let mut messages = TextMessages::new();
213
214 for message in &request.messages {
215 let mut text_content = String::new();
216
217 for content in &message.content {
218 match content {
219 MessageContent::Text(text) => {
220 text_content.push_str(text);
221 }
222 MessageContent::Image { .. } => {
223 // For now, skip image content
224 continue;
225 }
226 MessageContent::ToolResult { .. } => {
227 // Skip tool results for now
228 continue;
229 }
230 MessageContent::Thinking { .. } => {
231 // Skip thinking content
232 continue;
233 }
234 MessageContent::RedactedThinking(_) => {
235 // Skip redacted thinking
236 continue;
237 }
238 MessageContent::ToolUse(_) => {
239 // Skip tool use
240 continue;
241 }
242 }
243 }
244
245 if text_content.is_empty() {
246 continue;
247 }
248
249 let role = match message.role {
250 Role::User => TextMessageRole::User,
251 Role::Assistant => TextMessageRole::Assistant,
252 Role::System => TextMessageRole::System,
253 };
254
255 messages = messages.add_message(role, text_content);
256 }
257
258 messages
259 }
260}
261
impl LanguageModel for LocalLanguageModel {
    // The model id doubles as the Hugging Face repo id.
    fn id(&self) -> LanguageModelId {
        LanguageModelId(DEFAULT_MODEL.into())
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName(DEFAULT_MODEL.into())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn telemetry_id(&self) -> String {
        format!("local/{}", DEFAULT_MODEL)
    }

    // NOTE(review): advertises tool support, but `to_mistral_messages` drops
    // ToolUse/ToolResult content and `stream_completion` never emits tool-use
    // events — confirm whether this should return false.
    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        false
    }

    // NOTE(review): accepts every tool choice even though tool calls are not
    // actually wired through — see `supports_tools` above.
    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        true
    }

    fn max_token_count(&self) -> u64 {
        128000 // Qwen2.5 supports 128k context
    }

    /// Cheap token estimate used for budgeting; counts only text content
    /// (images/tool parts contribute nothing).
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // Rough estimation: 1 token ≈ 4 characters
        let mut total_chars = 0;
        for message in request.messages {
            for content in message.content {
                match content {
                    MessageContent::Text(text) => total_chars += text.len(),
                    _ => {}
                }
            }
        }
        let tokens = (total_chars / 4) as u64;
        futures::future::ready(Ok(tokens)).boxed()
    }

    /// Streams a completion from the locally loaded model.
    ///
    /// Fails fast (before streaming) when the model has not been loaded yet.
    /// Chunks are forwarded through a bounded mpsc channel by a detached smol
    /// task; dropping the returned stream closes the channel and ends that task
    /// on its next send.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        // Convert the request up front so the spawned future is 'static.
        let messages = self.to_mistral_messages(&request);
        let state = self.state.clone();
        let limiter = self.request_limiter.clone();

        cx.spawn(async move |cx| {
            let result: Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            > = limiter
                .run(async move {
                    // Grab the model handle; error out if the app was dropped
                    // or the model was never loaded.
                    let model = cx
                        .read_entity(&state, |state, _| {
                            eprintln!(
                                "Local model: Checking if model is loaded: {:?}",
                                state.status
                            );
                            state.model.clone()
                        })
                        .map_err(|_| {
                            LanguageModelCompletionError::Other(anyhow!("App state dropped"))
                        })?
                        .ok_or_else(|| {
                            eprintln!("Local model: Model is not loaded!");
                            LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
                        })?;

                    // Bounded channel: producer task applies backpressure if the
                    // consumer falls behind by more than 32 events.
                    let (mut tx, rx) = mpsc::channel(32);

                    // Spawn a task to handle the stream
                    let _ = smol::spawn(async move {
                        let mut stream = match model.stream_chat_request(messages).await {
                            Ok(stream) => stream,
                            Err(e) => {
                                // Surface the startup failure as the stream's
                                // first (and only) item rather than panicking.
                                let _ = tx
                                    .send(Err(LanguageModelCompletionError::Other(anyhow!(
                                        "Failed to start stream: {}",
                                        e
                                    ))))
                                    .await;
                                return;
                            }
                        };

                        while let Some(response) = stream.next().await {
                            // Map each mistralrs response onto at most one
                            // completion event; anything unrecognized is dropped.
                            let event = match response {
                                MistralResponse::Chunk(chunk) => {
                                    if let Some(choice) = chunk.choices.first() {
                                        if let Some(content) = &choice.delta.content {
                                            Some(Ok(LanguageModelCompletionEvent::Text(
                                                content.clone(),
                                            )))
                                        } else if let Some(finish_reason) = &choice.finish_reason {
                                            // Unknown finish reasons fall back to EndTurn.
                                            let stop_reason = match finish_reason.as_str() {
                                                "stop" => StopReason::EndTurn,
                                                "length" => StopReason::MaxTokens,
                                                _ => StopReason::EndTurn,
                                            };
                                            Some(Ok(LanguageModelCompletionEvent::Stop(
                                                stop_reason,
                                            )))
                                        } else {
                                            None
                                        }
                                    } else {
                                        None
                                    }
                                }
                                MistralResponse::Done(_response) => {
                                    // For now, we don't emit usage events since the format doesn't match
                                    None
                                }
                                _ => None,
                            };

                            if let Some(event) = event {
                                // Receiver dropped — stop producing.
                                if tx.send(event).await.is_err() {
                                    break;
                                }
                            }
                        }
                    })
                    .detach();

                    Ok(rx.boxed())
                })
                .await;

            result
        })
        .boxed()
    }
}
424
/// Settings-panel view showing load status and a load/reload button.
struct ConfigurationView {
    // Observed so the view re-renders when `ModelStatus` changes.
    state: Entity<State>,
}
428
429impl Render for ConfigurationView {
430 fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
431 let status = self.state.read(cx).status.clone();
432
433 div().size_full().child(
434 div()
435 .p_4()
436 .child(
437 div()
438 .flex()
439 .gap_2()
440 .items_center()
441 .child(match &status {
442 ModelStatus::NotLoaded => Label::new("Model not loaded"),
443 ModelStatus::Loading => Label::new("Loading model..."),
444 ModelStatus::Loaded => Label::new("Model loaded"),
445 ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
446 })
447 .child(match &status {
448 ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
449 ModelStatus::Loading => Indicator::dot().color(Color::Modified),
450 ModelStatus::Loaded => Indicator::dot().color(Color::Success),
451 ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
452 }),
453 )
454 .when(!matches!(status, ModelStatus::Loading), |this| {
455 this.child(
456 ButtonLike::new("load_model")
457 .child(Label::new(if matches!(status, ModelStatus::Loaded) {
458 "Reload Model"
459 } else {
460 "Load Model"
461 }))
462 .on_click(cx.listener(|this, _, _window, cx| {
463 this.state.update(cx, |state, cx| {
464 state.authenticate(cx).detach();
465 });
466 })),
467 )
468 }),
469 )
470 }
471}
472
473#[cfg(test)]
474mod tests;