1use anyhow::{Result, anyhow};
2use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
4use http_client::HttpClient;
5use language_model::{
6 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
8 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
9 LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
10};
11use mistralrs::{
12 IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
13 TextModelBuilder,
14};
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use ui::{ButtonLike, IconName, Indicator, prelude::*};
18
/// Stable machine identifier for this provider.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
/// Hugging Face model ID loaded by default; also reused as the model's id/name.
const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";
22
/// User-facing settings for the local provider.
///
/// NOTE(review): `available_models` is declared but not read anywhere in this
/// file — `provided_models` always returns the single default model. Confirm
/// whether settings-driven model listing is still planned.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
    // Extra models the user has configured (currently unused here).
    pub available_models: Vec<AvailableModel>,
}
27
/// A user-configured model entry (deserialized from settings).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AvailableModel {
    // Model identifier, presumably a Hugging Face repo id — TODO confirm.
    pub name: String,
    // Optional override for the name shown in the UI.
    pub display_name: Option<String>,
    // Maximum context window in tokens.
    pub max_tokens: u64,
}
34
/// Provider that runs models in-process via mistral.rs (no remote API).
pub struct LocalLanguageModelProvider {
    // Shared, observable model state (handle + load status).
    state: Entity<State>,
}
38
/// Observable provider state: the loaded model handle plus its lifecycle status.
pub struct State {
    // `Some` only after a successful load; `Arc` so in-flight completions keep
    // the model alive even if the state is reset.
    model: Option<Arc<MistralModel>>,
    status: ModelStatus,
}
43
/// Lifecycle of the local model load, mirrored in the configuration view.
#[derive(Clone, Debug, PartialEq)]
enum ModelStatus {
    NotLoaded,
    Loading,
    Loaded,
    // Holds the load-failure message for display.
    Error(String),
}
51
52impl State {
53 fn new(_cx: &mut Context<Self>) -> Self {
54 Self {
55 model: None,
56 status: ModelStatus::NotLoaded,
57 }
58 }
59
60 fn is_authenticated(&self) -> bool {
61 // Local models don't require authentication
62 true
63 }
64
65 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
66 if self.is_authenticated() {
67 return Task::ready(Ok(()));
68 }
69
70 if matches!(self.status, ModelStatus::Loading) {
71 return Task::ready(Ok(()));
72 }
73
74 self.status = ModelStatus::Loading;
75 cx.notify();
76
77 cx.spawn(async move |this, cx| match load_mistral_model().await {
78 Ok(model) => {
79 this.update(cx, |state, cx| {
80 state.model = Some(model);
81 state.status = ModelStatus::Loaded;
82 cx.notify();
83 })?;
84 Ok(())
85 }
86 Err(e) => {
87 let error_msg = e.to_string();
88 this.update(cx, |state, cx| {
89 state.status = ModelStatus::Error(error_msg.clone());
90 cx.notify();
91 })?;
92 Err(AuthenticateError::Other(anyhow!(
93 "Failed to load model: {}",
94 error_msg
95 )))
96 }
97 })
98 }
99}
100
101async fn load_mistral_model() -> Result<Arc<MistralModel>> {
102 println!("\n\n\n\nLoading mistral model...\n\n\n");
103 let model = TextModelBuilder::new(DEFAULT_MODEL)
104 .with_isq(IsqType::Q4_0)
105 .build()
106 .await?;
107
108 Ok(Arc::new(model))
109}
110
111impl LocalLanguageModelProvider {
112 pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
113 let state = cx.new(State::new);
114 Self { state }
115 }
116}
117
impl LanguageModelProviderState for LocalLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the state entity so the app can observe load-status changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
125
impl LanguageModelProvider for LocalLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::Ai
    }

    /// Always returns exactly one model (the hard-coded default).
    ///
    /// NOTE(review): `LocalSettings::available_models` is ignored here —
    /// confirm whether settings-configured models should be listed too.
    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        vec![Arc::new(LocalLanguageModel {
            state: self.state.clone(),
            // Allow up to 4 concurrent completion requests.
            request_limiter: RateLimiter::new(4),
        })]
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    /// There is only one local model, so "fast" is the same as the default.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn is_authenticated(&self, _cx: &App) -> bool {
        // Local models don't require authentication
        true
    }

    /// Delegates to `State::authenticate`, which handles model loading.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
        cx.new(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    /// Repurposed for local models: drops the loaded model and resets status,
    /// rather than clearing credentials (there are none).
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| {
            state.model = None;
            state.status = ModelStatus::NotLoaded;
            cx.notify();
        });
        Task::ready(Ok(()))
    }
}
179
/// A single in-process model exposed through the `LanguageModel` trait.
pub struct LocalLanguageModel {
    // Shared with the provider; the model handle is read from here per request.
    state: Entity<State>,
    // Caps concurrent completion requests.
    request_limiter: RateLimiter,
}
184
185impl LocalLanguageModel {
186 fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
187 let mut messages = TextMessages::new();
188
189 for message in &request.messages {
190 let mut text_content = String::new();
191
192 for content in &message.content {
193 match content {
194 MessageContent::Text(text) => {
195 text_content.push_str(text);
196 }
197 MessageContent::Image { .. } => {
198 // For now, skip image content
199 continue;
200 }
201 MessageContent::ToolResult { .. } => {
202 // Skip tool results for now
203 continue;
204 }
205 MessageContent::Thinking { .. } => {
206 // Skip thinking content
207 continue;
208 }
209 MessageContent::RedactedThinking(_) => {
210 // Skip redacted thinking
211 continue;
212 }
213 MessageContent::ToolUse(_) => {
214 // Skip tool use
215 continue;
216 }
217 }
218 }
219
220 if text_content.is_empty() {
221 continue;
222 }
223
224 let role = match message.role {
225 Role::User => TextMessageRole::User,
226 Role::Assistant => TextMessageRole::Assistant,
227 Role::System => TextMessageRole::System,
228 };
229
230 messages = messages.add_message(role, text_content);
231 }
232
233 messages
234 }
235}
236
impl LanguageModel for LocalLanguageModel {
    /// Reuses the Hugging Face model string as the model id.
    fn id(&self) -> LanguageModelId {
        LanguageModelId(DEFAULT_MODEL.into())
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName(DEFAULT_MODEL.into())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn telemetry_id(&self) -> String {
        format!("local/{}", DEFAULT_MODEL)
    }

    // Tool use and images are not wired up yet (see `to_mistral_messages`,
    // which drops that content).
    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn max_token_count(&self) -> u64 {
        128000 // GLM-4.5-Air supports 128k context
    }

    /// Estimates the token count without a tokenizer.
    ///
    /// Only `Text` content is counted, using the rough heuristic of
    /// 4 characters (bytes, via `str::len`) per token.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // Rough estimation: 1 token ≈ 4 characters
        let mut total_chars = 0;
        for message in request.messages {
            for content in message.content {
                match content {
                    MessageContent::Text(text) => total_chars += text.len(),
                    _ => {}
                }
            }
        }
        let tokens = (total_chars / 4) as u64;
        futures::future::ready(Ok(tokens)).boxed()
    }

    /// Streams a completion from the in-process model.
    ///
    /// Flow: run under the rate limiter → read the loaded model handle out of
    /// app state (erroring if `authenticate` hasn't loaded it yet) → spawn a
    /// detached smol task that drives the mistral.rs response stream and
    /// forwards events over a bounded mpsc channel → return the receiving end
    /// as the event stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        // Convert messages eagerly so the async block doesn't borrow `self`.
        let messages = self.to_mistral_messages(&request);
        let state = self.state.clone();
        let limiter = self.request_limiter.clone();

        cx.spawn(async move |cx| {
            let result: Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            > = limiter
                .run(async move {
                    let model = cx
                        .read_entity(&state, |state, _| state.model.clone())
                        .map_err(|_| {
                            LanguageModelCompletionError::Other(anyhow!("App state dropped"))
                        })?
                        .ok_or_else(|| {
                            LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
                        })?;

                    // Bounded channel: inference backpressures if the consumer
                    // falls more than 32 events behind.
                    let (mut tx, rx) = mpsc::channel(32);

                    // Spawn a task to handle the stream
                    let _ = smol::spawn(async move {
                        let mut stream = match model.stream_chat_request(messages).await {
                            Ok(stream) => stream,
                            Err(e) => {
                                // Surface startup failure as the first (and
                                // only) stream item.
                                let _ = tx
                                    .send(Err(LanguageModelCompletionError::Other(anyhow!(
                                        "Failed to start stream: {}",
                                        e
                                    ))))
                                    .await;
                                return;
                            }
                        };

                        while let Some(response) = stream.next().await {
                            let event = match response {
                                MistralResponse::Chunk(chunk) => {
                                    // NOTE(review): if a chunk carries both
                                    // `content` and `finish_reason`, only the
                                    // text is emitted — the Stop event relies
                                    // on a later content-less chunk. Confirm
                                    // mistral.rs sends finish_reason that way.
                                    if let Some(choice) = chunk.choices.first() {
                                        if let Some(content) = &choice.delta.content {
                                            Some(Ok(LanguageModelCompletionEvent::Text(
                                                content.clone(),
                                            )))
                                        } else if let Some(finish_reason) = &choice.finish_reason {
                                            let stop_reason = match finish_reason.as_str() {
                                                "stop" => StopReason::EndTurn,
                                                "length" => StopReason::MaxTokens,
                                                _ => StopReason::EndTurn,
                                            };
                                            Some(Ok(LanguageModelCompletionEvent::Stop(
                                                stop_reason,
                                            )))
                                        } else {
                                            None
                                        }
                                    } else {
                                        None
                                    }
                                }
                                MistralResponse::Done(_response) => {
                                    // For now, we don't emit usage events since the format doesn't match
                                    None
                                }
                                _ => None,
                            };

                            if let Some(event) = event {
                                // Receiver dropped: stop driving the stream.
                                if tx.send(event).await.is_err() {
                                    break;
                                }
                            }
                        }
                    })
                    .detach();

                    Ok(rx.boxed())
                })
                .await;

            result
        })
        .boxed()
    }
}
392
/// Settings-panel view showing load status and a load/reload button.
struct ConfigurationView {
    state: Entity<State>,
}
396
397impl Render for ConfigurationView {
398 fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
399 let status = self.state.read(cx).status.clone();
400
401 div().size_full().child(
402 div()
403 .p_4()
404 .child(
405 div()
406 .flex()
407 .gap_2()
408 .items_center()
409 .child(match &status {
410 ModelStatus::NotLoaded => Label::new("Model not loaded"),
411 ModelStatus::Loading => Label::new("Loading model..."),
412 ModelStatus::Loaded => Label::new("Model loaded"),
413 ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
414 })
415 .child(match &status {
416 ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
417 ModelStatus::Loading => Indicator::dot().color(Color::Modified),
418 ModelStatus::Loaded => Indicator::dot().color(Color::Success),
419 ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
420 }),
421 )
422 .when(!matches!(status, ModelStatus::Loading), |this| {
423 this.child(
424 ButtonLike::new("load_model")
425 .child(Label::new(if matches!(status, ModelStatus::Loaded) {
426 "Reload Model"
427 } else {
428 "Load Model"
429 }))
430 .on_click(cx.listener(|this, _, _window, cx| {
431 this.state.update(cx, |state, cx| {
432 state.authenticate(cx).detach();
433 });
434 })),
435 )
436 }),
437 )
438 }
439}
440
// Unit tests live in the sibling `tests` module file.
#[cfg(test)]
mod tests;