use anyhow::{Result, anyhow};
use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
};
use mistralrs::{
    IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
    TextModelBuilder,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::{ButtonLike, IconName, Indicator, prelude::*};

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
}

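/// Language model provider that serves a model running in-process via mistral.rs,
/// rather than calling a remote API.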
pub struct LocalLanguageModelProvider {
    state: Entity<State>,
}

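/// Shared provider state: the loaded mistral.rs model (if any) and its load status.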
pub struct State {
    model: Option<Arc<MistralModel>>,
    status: ModelStatus,
}

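/// Lifecycle of the local model, surfaced in the configuration view.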
#[derive(Clone, Debug, PartialEq)]
enum ModelStatus {
    NotLoaded,
    Loading,
    Loaded,
    Error(String),
}

impl State {
    fn new(_cx: &mut Context<Self>) -> Self {
        Self {
            model: None,
            status: ModelStatus::NotLoaded,
        }
    }

    fn is_authenticated(&self) -> bool {
        matches!(self.status, ModelStatus::Loaded)
    }

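    /// For this provider, "authentication" means loading the model into memory.
    /// Returns immediately if the model is already loaded or a load is in flight.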
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        if matches!(self.status, ModelStatus::Loading) {
            return Task::ready(Ok(()));
        }

        self.status = ModelStatus::Loading;
        cx.notify();

        cx.spawn(async move |this, cx| match load_mistral_model().await {
            Ok(model) => {
                this.update(cx, |state, cx| {
                    state.model = Some(model);
                    state.status = ModelStatus::Loaded;
                    cx.notify();
                })?;
                Ok(())
            }
            Err(e) => {
                let error_msg = e.to_string();
                this.update(cx, |state, cx| {
                    state.status = ModelStatus::Error(error_msg.clone());
                    cx.notify();
                })?;
                Err(AuthenticateError::Other(anyhow!(
                    "Failed to load model: {}",
                    error_msg
                )))
            }
        })
    }
}

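/// Loads the default model through mistral.rs, applying Q4_0 in-situ quantization.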
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
    let model = TextModelBuilder::new(DEFAULT_MODEL)
        .with_isq(IsqType::Q4_0)
        .with_logging()
        .build()
        .await?;

    Ok(Arc::new(model))
}

impl LocalLanguageModelProvider {
    pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(State::new);
        Self { state }
    }
}

impl LanguageModelProviderState for LocalLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LocalLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::Ai
    }

    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        vec![Arc::new(LocalLanguageModel {
            state: self.state.clone(),
            request_limiter: RateLimiter::new(4),
        })]
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
        cx.new(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| {
            state.model = None;
            state.status = ModelStatus::NotLoaded;
            cx.notify();
        });
        Task::ready(Ok(()))
    }
}

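/// The single locally hosted model exposed through the `LanguageModel` trait.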
pub struct LocalLanguageModel {
    state: Entity<State>,
    request_limiter: RateLimiter,
}

impl LocalLanguageModel {
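    /// Converts a Zed completion request into mistral.rs `TextMessages`,
    /// keeping only plain-text content and dropping empty messages.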
    fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
        let mut messages = TextMessages::new();

        for message in &request.messages {
            let mut text_content = String::new();

            for content in &message.content {
                match content {
                    MessageContent::Text(text) => {
                        text_content.push_str(text);
                    }
                    // Images, tool calls/results, and thinking content are not
                    // supported yet; skip them.
                    MessageContent::Image { .. }
                    | MessageContent::ToolResult { .. }
                    | MessageContent::Thinking { .. }
                    | MessageContent::RedactedThinking(_)
                    | MessageContent::ToolUse(_) => continue,
                }
            }

            if text_content.is_empty() {
                continue;
            }

            let role = match message.role {
                Role::User => TextMessageRole::User,
                Role::Assistant => TextMessageRole::Assistant,
                Role::System => TextMessageRole::System,
            };

            messages = messages.add_message(role, text_content);
        }

        messages
    }
}

impl LanguageModel for LocalLanguageModel {
    fn id(&self) -> LanguageModelId {
        LanguageModelId(DEFAULT_MODEL.into())
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName(DEFAULT_MODEL.into())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn telemetry_id(&self) -> String {
        format!("local/{}", DEFAULT_MODEL)
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn max_token_count(&self) -> u64 {
        128000 // GLM-4.5-Air supports a 128k context window
    }

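    /// Estimates token usage with a rough 4-characters-per-token heuristic;
    /// no tokenizer is consulted.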
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // Rough estimation: 1 token ≈ 4 characters
        let mut total_chars = 0;
        for message in request.messages {
            for content in message.content {
                if let MessageContent::Text(text) = content {
                    total_chars += text.len();
                }
            }
        }
        let tokens = (total_chars / 4) as u64;
        futures::future::ready(Ok(tokens)).boxed()
    }

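    /// Streams a completion from the locally loaded model, translating mistral.rs
    /// response chunks into `LanguageModelCompletionEvent`s over an mpsc channel.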
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let messages = self.to_mistral_messages(&request);
        let state = self.state.clone();
        let limiter = self.request_limiter.clone();

        cx.spawn(async move |cx| {
            let result: Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            > = limiter
                .run(async move {
                    let model = cx
                        .read_entity(&state, |state, _| state.model.clone())
                        .map_err(|_| {
                            LanguageModelCompletionError::Other(anyhow!("App state dropped"))
                        })?
                        .ok_or_else(|| {
                            LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
                        })?;

                    let (mut tx, rx) = mpsc::channel(32);

                    // Spawn a background task to handle the stream
                    smol::spawn(async move {
                        let mut stream = match model.stream_chat_request(messages).await {
                            Ok(stream) => stream,
                            Err(e) => {
                                let _ = tx
                                    .send(Err(LanguageModelCompletionError::Other(anyhow!(
                                        "Failed to start stream: {}",
                                        e
                                    ))))
                                    .await;
                                return;
                            }
                        };

                        while let Some(response) = stream.next().await {
                            let event = match response {
                                MistralResponse::Chunk(chunk) => {
                                    if let Some(choice) = chunk.choices.first() {
                                        if let Some(content) = &choice.delta.content {
                                            Some(Ok(LanguageModelCompletionEvent::Text(
                                                content.clone(),
                                            )))
                                        } else if let Some(finish_reason) = &choice.finish_reason {
                                            let stop_reason = match finish_reason.as_str() {
                                                "stop" => StopReason::EndTurn,
                                                "length" => StopReason::MaxTokens,
                                                _ => StopReason::EndTurn,
                                            };
                                            Some(Ok(LanguageModelCompletionEvent::Stop(
                                                stop_reason,
                                            )))
                                        } else {
                                            None
                                        }
                                    } else {
                                        None
                                    }
                                }
                                MistralResponse::Done(_response) => {
                                    // For now, we don't emit usage events since the format doesn't match
                                    None
                                }
                                _ => None,
                            };

                            if let Some(event) = event {
                                if tx.send(event).await.is_err() {
                                    break;
                                }
                            }
                        }
                    })
                    .detach();

                    Ok(rx.boxed())
                })
                .await;

            result
        })
        .boxed()
    }
}

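/// Settings-panel view that shows the model's load status and, when not loading,
/// a Load/Reload button.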
struct ConfigurationView {
    state: Entity<State>,
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
        let status = self.state.read(cx).status.clone();

        div().size_full().child(
            div()
                .p_4()
                .child(
                    div()
                        .flex()
                        .gap_2()
                        .items_center()
                        .child(match &status {
                            ModelStatus::NotLoaded => Label::new("Model not loaded"),
                            ModelStatus::Loading => Label::new("Loading model..."),
                            ModelStatus::Loaded => Label::new("Model loaded"),
                            ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
                        })
                        .child(match &status {
                            ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
                            ModelStatus::Loading => Indicator::dot().color(Color::Modified),
                            ModelStatus::Loaded => Indicator::dot().color(Color::Success),
                            ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
                        }),
                )
                .when(!matches!(status, ModelStatus::Loading), |this| {
                    this.child(
                        ButtonLike::new("load_model")
                            .child(Label::new(if matches!(status, ModelStatus::Loaded) {
                                "Reload Model"
                            } else {
                                "Load Model"
                            }))
                            .on_click(cx.listener(|this, _, _window, cx| {
                                this.state.update(cx, |state, cx| {
                                    state.authenticate(cx).detach();
                                });
                            })),
                    )
                }),
        )
    }
}

#[cfg(test)]
mod tests;