mistral.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use collections::BTreeMap;
  3use credentials_provider::CredentialsProvider;
  4use editor::{Editor, EditorElement, EditorStyle};
  5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
  6use gpui::{
  7    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
  8};
  9use http_client::HttpClient;
 10use language_model::{
 11    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 12    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 13    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 14    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
 15    RateLimiter, Role, StopReason, TokenUsage,
 16};
 17use mistral::StreamResponse;
 18pub use settings::MistralAvailableModel as AvailableModel;
 19use settings::{Settings, SettingsStore};
 20use std::collections::HashMap;
 21use std::pin::Pin;
 22use std::str::FromStr;
 23use std::sync::Arc;
 24use strum::IntoEnumIterator;
 25use theme::ThemeSettings;
 26use ui::{Icon, IconName, List, Tooltip, prelude::*};
 27use util::ResultExt;
 28
 29use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 30
/// Identifier of the Mistral provider, returned by the provider's `id` and by each model's
/// `provider_id`.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Human-readable provider name, returned by the provider's `name` and each model's
/// `provider_name`.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
 33
/// Settings for the Mistral provider (presumably the type of
/// `AllLanguageModelSettings::mistral` — confirm against the settings module).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// Base URL of the Mistral API. It is also the key under which the API key is stored in
    /// the credentials provider.
    pub api_url: String,
    /// Extra models declared in user settings. They are exposed as `mistral::Model::Custom`
    /// and replace a built-in model whose id equals their `name`.
    pub available_models: Vec<AvailableModel>,
}
 39
/// Language model provider backed by the Mistral API.
pub struct MistralLanguageModelProvider {
    /// HTTP client handed to every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with the models and the configuration view.
    state: gpui::Entity<State>,
}
 44
/// Authentication state shared by the provider, its models and the configuration view.
pub struct State {
    /// The API key in use, if one has been loaded or entered.
    api_key: Option<String>,
    /// Whether `api_key` came from the `MISTRAL_API_KEY` environment variable rather than
    /// the credentials store.
    api_key_from_env: bool,
    /// Keeps the settings-store observation alive; it re-notifies observers of this state
    /// whenever global settings change.
    _subscription: Subscription,
}
 50
/// Environment variable checked for an API key before falling back to stored credentials.
const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY";
 52
 53impl State {
 54    fn is_authenticated(&self) -> bool {
 55        self.api_key.is_some()
 56    }
 57
 58    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 59        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 60        let api_url = AllLanguageModelSettings::get_global(cx)
 61            .mistral
 62            .api_url
 63            .clone();
 64        cx.spawn(async move |this, cx| {
 65            credentials_provider
 66                .delete_credentials(&api_url, cx)
 67                .await
 68                .log_err();
 69            this.update(cx, |this, cx| {
 70                this.api_key = None;
 71                this.api_key_from_env = false;
 72                cx.notify();
 73            })
 74        })
 75    }
 76
 77    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 78        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 79        let api_url = AllLanguageModelSettings::get_global(cx)
 80            .mistral
 81            .api_url
 82            .clone();
 83        cx.spawn(async move |this, cx| {
 84            credentials_provider
 85                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
 86                .await?;
 87            this.update(cx, |this, cx| {
 88                this.api_key = Some(api_key);
 89                cx.notify();
 90            })
 91        })
 92    }
 93
 94    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 95        if self.is_authenticated() {
 96            return Task::ready(Ok(()));
 97        }
 98
 99        let credentials_provider = <dyn CredentialsProvider>::global(cx);
100        let api_url = AllLanguageModelSettings::get_global(cx)
101            .mistral
102            .api_url
103            .clone();
104        cx.spawn(async move |this, cx| {
105            let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
106                (api_key, true)
107            } else {
108                let (_, api_key) = credentials_provider
109                    .read_credentials(&api_url, cx)
110                    .await?
111                    .ok_or(AuthenticateError::CredentialsNotFound)?;
112                (
113                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
114                    false,
115                )
116            };
117            this.update(cx, |this, cx| {
118                this.api_key = Some(api_key);
119                this.api_key_from_env = from_env;
120                cx.notify();
121            })?;
122
123            Ok(())
124        })
125    }
126}
127
128impl MistralLanguageModelProvider {
129    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
130        let state = cx.new(|cx| State {
131            api_key: None,
132            api_key_from_env: false,
133            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
134                cx.notify();
135            }),
136        });
137
138        Self { http_client, state }
139    }
140
141    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
142        Arc::new(MistralLanguageModel {
143            id: LanguageModelId::from(model.id().to_string()),
144            model,
145            state: self.state.clone(),
146            http_client: self.http_client.clone(),
147            request_limiter: RateLimiter::new(4),
148        })
149    }
150}
151
152impl LanguageModelProviderState for MistralLanguageModelProvider {
153    type ObservableEntity = State;
154
155    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
156        Some(self.state.clone())
157    }
158}
159
160impl LanguageModelProvider for MistralLanguageModelProvider {
161    fn id(&self) -> LanguageModelProviderId {
162        PROVIDER_ID
163    }
164
165    fn name(&self) -> LanguageModelProviderName {
166        PROVIDER_NAME
167    }
168
169    fn icon(&self) -> IconName {
170        IconName::AiMistral
171    }
172
173    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
174        Some(self.create_language_model(mistral::Model::default()))
175    }
176
177    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
178        Some(self.create_language_model(mistral::Model::default_fast()))
179    }
180
181    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
182        let mut models = BTreeMap::default();
183
184        // Add base models from mistral::Model::iter()
185        for model in mistral::Model::iter() {
186            if !matches!(model, mistral::Model::Custom { .. }) {
187                models.insert(model.id().to_string(), model);
188            }
189        }
190
191        // Override with available models from settings
192        for model in &AllLanguageModelSettings::get_global(cx)
193            .mistral
194            .available_models
195        {
196            models.insert(
197                model.name.clone(),
198                mistral::Model::Custom {
199                    name: model.name.clone(),
200                    display_name: model.display_name.clone(),
201                    max_tokens: model.max_tokens,
202                    max_output_tokens: model.max_output_tokens,
203                    max_completion_tokens: model.max_completion_tokens,
204                    supports_tools: model.supports_tools,
205                    supports_images: model.supports_images,
206                    supports_thinking: model.supports_thinking,
207                },
208            );
209        }
210
211        models
212            .into_values()
213            .map(|model| {
214                Arc::new(MistralLanguageModel {
215                    id: LanguageModelId::from(model.id().to_string()),
216                    model,
217                    state: self.state.clone(),
218                    http_client: self.http_client.clone(),
219                    request_limiter: RateLimiter::new(4),
220                }) as Arc<dyn LanguageModel>
221            })
222            .collect()
223    }
224
225    fn is_authenticated(&self, cx: &App) -> bool {
226        self.state.read(cx).is_authenticated()
227    }
228
229    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
230        self.state.update(cx, |state, cx| state.authenticate(cx))
231    }
232
233    fn configuration_view(
234        &self,
235        _target_agent: language_model::ConfigurationViewTargetAgent,
236        window: &mut Window,
237        cx: &mut App,
238    ) -> AnyView {
239        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
240            .into()
241    }
242
243    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
244        self.state.update(cx, |state, cx| state.reset_api_key(cx))
245    }
246}
247
/// A single Mistral model exposed through the `LanguageModel` trait.
pub struct MistralLanguageModel {
    /// Identifier derived from `model.id()`.
    id: LanguageModelId,
    model: mistral::Model,
    /// Shared provider state; the API key is read from it for each request.
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Rate limiter that every streaming request from this model goes through.
    request_limiter: RateLimiter,
}
255
256impl MistralLanguageModel {
257    fn stream_completion(
258        &self,
259        request: mistral::Request,
260        cx: &AsyncApp,
261    ) -> BoxFuture<
262        'static,
263        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
264    > {
265        let http_client = self.http_client.clone();
266        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
267            let settings = &AllLanguageModelSettings::get_global(cx).mistral;
268            (state.api_key.clone(), settings.api_url.clone())
269        }) else {
270            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
271        };
272
273        let future = self.request_limiter.stream(async move {
274            let api_key = api_key.context("Missing Mistral API Key")?;
275            let request =
276                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
277            let response = request.await?;
278            Ok(response)
279        });
280
281        async move { Ok(future.await?.boxed()) }.boxed()
282    }
283}
284
285impl LanguageModel for MistralLanguageModel {
286    fn id(&self) -> LanguageModelId {
287        self.id.clone()
288    }
289
290    fn name(&self) -> LanguageModelName {
291        LanguageModelName::from(self.model.display_name().to_string())
292    }
293
294    fn provider_id(&self) -> LanguageModelProviderId {
295        PROVIDER_ID
296    }
297
298    fn provider_name(&self) -> LanguageModelProviderName {
299        PROVIDER_NAME
300    }
301
302    fn supports_tools(&self) -> bool {
303        self.model.supports_tools()
304    }
305
306    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
307        self.model.supports_tools()
308    }
309
310    fn supports_images(&self) -> bool {
311        self.model.supports_images()
312    }
313
314    fn telemetry_id(&self) -> String {
315        format!("mistral/{}", self.model.id())
316    }
317
318    fn max_token_count(&self) -> u64 {
319        self.model.max_token_count()
320    }
321
322    fn max_output_tokens(&self) -> Option<u64> {
323        self.model.max_output_tokens()
324    }
325
326    fn count_tokens(
327        &self,
328        request: LanguageModelRequest,
329        cx: &App,
330    ) -> BoxFuture<'static, Result<u64>> {
331        cx.background_spawn(async move {
332            let messages = request
333                .messages
334                .into_iter()
335                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
336                    role: match message.role {
337                        Role::User => "user".into(),
338                        Role::Assistant => "assistant".into(),
339                        Role::System => "system".into(),
340                    },
341                    content: Some(message.string_contents()),
342                    name: None,
343                    function_call: None,
344                })
345                .collect::<Vec<_>>();
346
347            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
348        })
349        .boxed()
350    }
351
352    fn stream_completion(
353        &self,
354        request: LanguageModelRequest,
355        cx: &AsyncApp,
356    ) -> BoxFuture<
357        'static,
358        Result<
359            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
360            LanguageModelCompletionError,
361        >,
362    > {
363        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
364        let stream = self.stream_completion(request, cx);
365
366        async move {
367            let stream = stream.await?;
368            let mapper = MistralEventMapper::new();
369            Ok(mapper.map_stream(stream).boxed())
370        }
371        .boxed()
372    }
373}
374
/// Converts a provider-agnostic `LanguageModelRequest` into a streaming Mistral request.
///
/// Per role:
/// - User: text, images (only if the model supports images) and thinking (only if the model
///   supports thinking) are collected into one user message. Tool results become separate
///   `Tool` messages, pushed before that user message. The user message is omitted when it
///   is still empty plain content.
/// - Assistant: each text item, and each thinking item when supported, becomes its own
///   assistant message. A tool use is appended to the last message if that is an assistant
///   message; otherwise a new content-less assistant message is created for it.
/// - System: each text item, and each thinking item when supported, becomes its own system
///   message.
///
/// Tool choice: `None` is always forwarded. `Auto`/`Any` are forwarded only when tools are
/// present. With tools but no explicit choice, `Auto` is used. Parallel tool calls are
/// disabled whenever tools are present.
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> mistral::Request {
    // This conversion always produces a streaming request.
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            // Images are silently dropped for models without image support.
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            // Tool results are emitted immediately as standalone `Tool`
                            // messages, so they precede this message's user content.
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                // Skip the user message entirely if nothing was added to it.
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        // Serialization failure falls back to an empty string.
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            // Attach to the previous assistant message when possible so text
                            // and tool calls of one turn share a message.
                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    mistral::Request {
        model: model.id().to_string(),
        messages,
        stream,
        max_tokens: max_output_tokens,
        temperature: request.temperature,
        response_format: None,
        tool_choice: match request.tool_choice {
            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Auto)
            }
            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Any)
            }
            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
            // Tools present but no explicit choice: let the model decide.
            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
            _ => None,
        },
        parallel_tool_calls: if !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| mistral::ToolDefinition::Function {
                function: mistral::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
    }
}
561
/// Converts Mistral streaming chunks into `LanguageModelCompletionEvent`s. Tool-call
/// fragments are buffered until a chunk reports `finish_reason == "tool_calls"`.
pub struct MistralEventMapper {
    /// Partially received tool calls, keyed by the `index` sent with each fragment.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
565
566impl MistralEventMapper {
567    pub fn new() -> Self {
568        Self {
569            tool_calls_by_index: HashMap::default(),
570        }
571    }
572
573    pub fn map_stream(
574        mut self,
575        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
576    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
577    {
578        events.flat_map(move |event| {
579            futures::stream::iter(match event {
580                Ok(event) => self.map_event(event),
581                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
582            })
583        })
584    }
585
586    pub fn map_event(
587        &mut self,
588        event: mistral::StreamResponse,
589    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
590        let Some(choice) = event.choices.first() else {
591            return vec![Err(LanguageModelCompletionError::from(anyhow!(
592                "Response contained no choices"
593            )))];
594        };
595
596        let mut events = Vec::new();
597        if let Some(content) = choice.delta.content.as_ref() {
598            match content {
599                mistral::MessageContentDelta::Text(text) => {
600                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
601                }
602                mistral::MessageContentDelta::Parts(parts) => {
603                    for part in parts {
604                        match part {
605                            mistral::MessagePart::Text { text } => {
606                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
607                            }
608                            mistral::MessagePart::Thinking { thinking } => {
609                                for tp in thinking.iter().cloned() {
610                                    match tp {
611                                        mistral::ThinkingPart::Text { text } => {
612                                            events.push(Ok(
613                                                LanguageModelCompletionEvent::Thinking {
614                                                    text,
615                                                    signature: None,
616                                                },
617                                            ));
618                                        }
619                                    }
620                                }
621                            }
622                            mistral::MessagePart::ImageUrl { .. } => {
623                                // We currently don't emit a separate event for images in responses.
624                            }
625                        }
626                    }
627                }
628            }
629        }
630
631        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
632            for tool_call in tool_calls {
633                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
634
635                if let Some(tool_id) = tool_call.id.clone() {
636                    entry.id = tool_id;
637                }
638
639                if let Some(function) = tool_call.function.as_ref() {
640                    if let Some(name) = function.name.clone() {
641                        entry.name = name;
642                    }
643
644                    if let Some(arguments) = function.arguments.clone() {
645                        entry.arguments.push_str(&arguments);
646                    }
647                }
648            }
649        }
650
651        if let Some(usage) = event.usage {
652            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
653                input_tokens: usage.prompt_tokens,
654                output_tokens: usage.completion_tokens,
655                cache_creation_input_tokens: 0,
656                cache_read_input_tokens: 0,
657            })));
658        }
659
660        if let Some(finish_reason) = choice.finish_reason.as_deref() {
661            match finish_reason {
662                "stop" => {
663                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
664                }
665                "tool_calls" => {
666                    events.extend(self.process_tool_calls());
667                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
668                }
669                unexpected => {
670                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
671                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
672                }
673            }
674        }
675
676        events
677    }
678
679    fn process_tool_calls(
680        &mut self,
681    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
682        let mut results = Vec::new();
683
684        for (_, tool_call) in self.tool_calls_by_index.drain() {
685            if tool_call.id.is_empty() || tool_call.name.is_empty() {
686                results.push(Err(LanguageModelCompletionError::from(anyhow!(
687                    "Received incomplete tool call: missing id or name"
688                ))));
689                continue;
690            }
691
692            match serde_json::Value::from_str(&tool_call.arguments) {
693                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
694                    LanguageModelToolUse {
695                        id: tool_call.id.into(),
696                        name: tool_call.name.into(),
697                        is_input_complete: true,
698                        input,
699                        raw_input: tool_call.arguments,
700                    },
701                ))),
702                Err(error) => {
703                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
704                        id: tool_call.id.into(),
705                        tool_name: tool_call.name.into(),
706                        raw_input: tool_call.arguments.into(),
707                        json_parse_error: error.to_string(),
708                    }))
709                }
710            }
711        }
712
713        results
714    }
715}
716
/// Accumulates the streamed fragments of one tool call until it is emitted.
#[derive(Default)]
struct RawToolCall {
    /// Tool call id; empty until a fragment provides one.
    id: String,
    /// Function name; empty until a fragment provides one.
    name: String,
    /// Concatenation of all argument fragments received so far, later parsed as JSON.
    arguments: String,
}
723
/// Settings UI for entering, resetting and inspecting the Mistral API key.
struct ConfigurationView {
    /// Single-line editor the API key is typed into.
    api_key_editor: Entity<Editor>,
    /// Shared provider authentication state.
    state: gpui::Entity<State>,
    /// Initial credential load; `Some` while it runs, set to `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
}
729
730impl ConfigurationView {
731    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
732        let api_key_editor = cx.new(|cx| {
733            let mut editor = Editor::single_line(window, cx);
734            editor.set_placeholder_text("0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2", window, cx);
735            editor
736        });
737
738        cx.observe(&state, |_, _, cx| {
739            cx.notify();
740        })
741        .detach();
742
743        let load_credentials_task = Some(cx.spawn_in(window, {
744            let state = state.clone();
745            async move |this, cx| {
746                if let Some(task) = state
747                    .update(cx, |state, cx| state.authenticate(cx))
748                    .log_err()
749                {
750                    // We don't log an error, because "not signed in" is also an error.
751                    let _ = task.await;
752                }
753
754                this.update(cx, |this, cx| {
755                    this.load_credentials_task = None;
756                    cx.notify();
757                })
758                .log_err();
759            }
760        }));
761
762        Self {
763            api_key_editor,
764            state,
765            load_credentials_task,
766        }
767    }
768
769    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
770        let api_key = self.api_key_editor.read(cx).text(cx);
771        if api_key.is_empty() {
772            return;
773        }
774
775        let state = self.state.clone();
776        cx.spawn_in(window, async move |_, cx| {
777            state
778                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
779                .await
780        })
781        .detach_and_log_err(cx);
782
783        cx.notify();
784    }
785
786    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
787        self.api_key_editor
788            .update(cx, |editor, cx| editor.set_text("", window, cx));
789
790        let state = self.state.clone();
791        cx.spawn_in(window, async move |_, cx| {
792            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
793        })
794        .detach_and_log_err(cx);
795
796        cx.notify();
797    }
798
799    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
800        let settings = ThemeSettings::get_global(cx);
801        let text_style = TextStyle {
802            color: cx.theme().colors().text,
803            font_family: settings.ui_font.family.clone(),
804            font_features: settings.ui_font.features.clone(),
805            font_fallbacks: settings.ui_font.fallbacks.clone(),
806            font_size: rems(0.875).into(),
807            font_weight: settings.ui_font.weight,
808            font_style: FontStyle::Normal,
809            line_height: relative(1.3),
810            white_space: WhiteSpace::Normal,
811            ..Default::default()
812        };
813        EditorElement::new(
814            &self.api_key_editor,
815            EditorStyle {
816                background: cx.theme().colors().editor_background,
817                local_player: cx.theme().players().local(),
818                text: text_style,
819                ..Default::default()
820            },
821        )
822    }
823
824    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
825        !self.state.read(cx).is_authenticated()
826    }
827}
828
829impl Render for ConfigurationView {
830    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
831        let env_var_set = self.state.read(cx).api_key_from_env;
832
833        if self.load_credentials_task.is_some() {
834            div().child(Label::new("Loading credentials...")).into_any()
835        } else if self.should_render_editor(cx) {
836            v_flex()
837                .size_full()
838                .on_action(cx.listener(Self::save_api_key))
839                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
840                .child(
841                    List::new()
842                        .child(InstructionListItem::new(
843                            "Create one by visiting",
844                            Some("Mistral's console"),
845                            Some("https://console.mistral.ai/api-keys"),
846                        ))
847                        .child(InstructionListItem::text_only(
848                            "Ensure your Mistral account has credits",
849                        ))
850                        .child(InstructionListItem::text_only(
851                            "Paste your API key below and hit enter to start using the assistant",
852                        )),
853                )
854                .child(
855                    h_flex()
856                        .w_full()
857                        .my_2()
858                        .px_2()
859                        .py_1()
860                        .bg(cx.theme().colors().editor_background)
861                        .border_1()
862                        .border_color(cx.theme().colors().border)
863                        .rounded_sm()
864                        .child(self.render_api_key_editor(cx)),
865                )
866                .child(
867                    Label::new(
868                        format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."),
869                    )
870                    .size(LabelSize::Small).color(Color::Muted),
871                )
872                .into_any()
873        } else {
874            h_flex()
875                .mt_1()
876                .p_1()
877                .justify_between()
878                .rounded_md()
879                .border_1()
880                .border_color(cx.theme().colors().border)
881                .bg(cx.theme().colors().background)
882                .child(
883                    h_flex()
884                        .gap_1()
885                        .child(Icon::new(IconName::Check).color(Color::Success))
886                        .child(Label::new(if env_var_set {
887                            format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.")
888                        } else {
889                            "API key configured.".to_string()
890                        })),
891                )
892                .child(
893                    Button::new("reset-key", "Reset Key")
894                        .label_size(LabelSize::Small)
895                        .icon(Some(IconName::Trash))
896                        .icon_size(IconSize::Small)
897                        .icon_position(IconPosition::Start)
898                        .disabled(env_var_set)
899                        .when(env_var_set, |this| {
900                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable.")))
901                        })
902                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
903                )
904                .into_any()
905        }
906    }
907}
908
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// Wraps `content` in an uncached request message with the given role.
    fn message(role: Role, content: Vec<MessageContent>) -> LanguageModelRequestMessage {
        LanguageModelRequestMessage {
            role,
            content,
            cache: false,
        }
    }

    /// Builds a request around `messages` with no tools, stops or temperature,
    /// and thinking allowed. Tests override individual fields as needed.
    fn build_request(messages: Vec<LanguageModelRequestMessage>) -> LanguageModelRequest {
        LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages,
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        }
    }

    #[test]
    fn test_into_mistral_basic_conversion() {
        let mut request = build_request(vec![
            message(Role::System, vec![MessageContent::Text("System prompt".into())]),
            message(Role::User, vec![MessageContent::Text("Hello".into())]),
        ]);
        request.temperature = Some(0.5);

        let converted = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(converted.model, "mistral-small-latest");
        assert_eq!(converted.temperature, Some(0.5));
        assert_eq!(converted.messages.len(), 2);
        assert!(converted.stream);
    }

    #[test]
    fn test_into_mistral_with_image() {
        let image = MessageContent::Image(LanguageModelImage {
            source: "base64data".into(),
            size: Default::default(),
        });
        let request = build_request(vec![message(
            Role::User,
            vec![MessageContent::Text("What's in this image?".into()), image],
        )]);

        let converted = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(converted.messages.len(), 1);

        // Destructuring with `let ... else` both asserts the message shape and
        // binds its parts, so the pattern does not have to be written twice.
        let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &converted.messages[0]
        else {
            panic!("expected the first message to be a multipart user message");
        };

        assert_eq!(content.len(), 2);
        assert!(matches!(
            &content[0],
            mistral::MessagePart::Text { text } if text == "What's in this image?"
        ));
        assert!(matches!(
            &content[1],
            mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
        ));
    }
}