1use anyhow::{Result, anyhow};
2use collections::BTreeMap;
3use credentials_provider::CredentialsProvider;
4
5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
6use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
7use http_client::HttpClient;
8use language_model::{
9 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
10 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
13 LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
14};
15pub use mistral::{MISTRAL_API_URL, StreamResponse};
16pub use settings::MistralAvailableModel as AvailableModel;
17use settings::{Settings, SettingsStore};
18use std::collections::HashMap;
19use std::pin::Pin;
20use std::sync::{Arc, LazyLock};
21use strum::IntoEnumIterator;
22use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
23use ui_input::InputField;
24use util::ResultExt;
25
26use language_model::util::{fix_streamed_json, parse_tool_arguments};
27
/// Stable identifier used to register this provider with the model registry.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Environment variable consulted for an API key before falling back to the
/// credential store.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
33
/// User-configurable settings for the Mistral provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    // Base URL for the Mistral API; an empty string means use MISTRAL_API_URL.
    pub api_url: String,
    // Extra models configured by the user; merged over the built-in catalog.
    pub available_models: Vec<AvailableModel>,
}
39
/// Language model provider backed by the Mistral API.
pub struct MistralLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared authentication state (API key), observed by the configuration UI.
    pub state: Entity<State>,
}
44
/// Authentication state: the API key (per API URL) and the credential store
/// used to persist it.
pub struct State {
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
49
50impl State {
51 fn is_authenticated(&self) -> bool {
52 self.api_key_state.has_key()
53 }
54
55 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
56 let credentials_provider = self.credentials_provider.clone();
57 let api_url = MistralLanguageModelProvider::api_url(cx);
58 self.api_key_state.store(
59 api_url,
60 api_key,
61 |this| &mut this.api_key_state,
62 credentials_provider,
63 cx,
64 )
65 }
66
67 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
68 let credentials_provider = self.credentials_provider.clone();
69 let api_url = MistralLanguageModelProvider::api_url(cx);
70 self.api_key_state.load_if_needed(
71 api_url,
72 |this| &mut this.api_key_state,
73 credentials_provider,
74 cx,
75 )
76 }
77}
78
// Newtype wrapper so the provider can be stored as a gpui global.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
82
83impl MistralLanguageModelProvider {
84 pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
85 cx.try_global::<GlobalMistralLanguageModelProvider>()
86 .map(|this| &this.0)
87 }
88
89 pub fn global(
90 http_client: Arc<dyn HttpClient>,
91 credentials_provider: Arc<dyn CredentialsProvider>,
92 cx: &mut App,
93 ) -> Arc<Self> {
94 if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
95 return this.0.clone();
96 }
97 let state = cx.new(|cx| {
98 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
99 let credentials_provider = this.credentials_provider.clone();
100 let api_url = Self::api_url(cx);
101 this.api_key_state.handle_url_change(
102 api_url,
103 |this| &mut this.api_key_state,
104 credentials_provider,
105 cx,
106 );
107 cx.notify();
108 })
109 .detach();
110 State {
111 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
112 credentials_provider,
113 }
114 });
115
116 let this = Arc::new(Self { http_client, state });
117 cx.set_global(GlobalMistralLanguageModelProvider(this));
118 cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
119 }
120
121 fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
122 Arc::new(MistralLanguageModel {
123 id: LanguageModelId::from(model.id().to_string()),
124 model,
125 state: self.state.clone(),
126 http_client: self.http_client.clone(),
127 request_limiter: RateLimiter::new(4),
128 })
129 }
130
131 fn settings(cx: &App) -> &MistralSettings {
132 &crate::AllLanguageModelSettings::get_global(cx).mistral
133 }
134
135 pub fn api_url(cx: &App) -> SharedString {
136 let api_url = &Self::settings(cx).api_url;
137 if api_url.is_empty() {
138 mistral::MISTRAL_API_URL.into()
139 } else {
140 SharedString::new(api_url.as_str())
141 }
142 }
143}
144
impl LanguageModelProviderState for MistralLanguageModelProvider {
    type ObservableEntity = State;

    // Exposes the auth state entity so observers can re-render when it changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
152
153impl LanguageModelProvider for MistralLanguageModelProvider {
154 fn id(&self) -> LanguageModelProviderId {
155 PROVIDER_ID
156 }
157
158 fn name(&self) -> LanguageModelProviderName {
159 PROVIDER_NAME
160 }
161
162 fn icon(&self) -> IconOrSvg {
163 IconOrSvg::Icon(IconName::AiMistral)
164 }
165
166 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
167 Some(self.create_language_model(mistral::Model::default()))
168 }
169
170 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
171 Some(self.create_language_model(mistral::Model::default_fast()))
172 }
173
174 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
175 let mut models = BTreeMap::default();
176
177 // Add base models from mistral::Model::iter()
178 for model in mistral::Model::iter() {
179 if !matches!(model, mistral::Model::Custom { .. }) {
180 models.insert(model.id().to_string(), model);
181 }
182 }
183
184 // Override with available models from settings
185 for model in &Self::settings(cx).available_models {
186 models.insert(
187 model.name.clone(),
188 mistral::Model::Custom {
189 name: model.name.clone(),
190 display_name: model.display_name.clone(),
191 max_tokens: model.max_tokens,
192 max_output_tokens: model.max_output_tokens,
193 max_completion_tokens: model.max_completion_tokens,
194 supports_tools: model.supports_tools,
195 supports_images: model.supports_images,
196 supports_thinking: model.supports_thinking,
197 },
198 );
199 }
200
201 models
202 .into_values()
203 .map(|model| {
204 Arc::new(MistralLanguageModel {
205 id: LanguageModelId::from(model.id().to_string()),
206 model,
207 state: self.state.clone(),
208 http_client: self.http_client.clone(),
209 request_limiter: RateLimiter::new(4),
210 }) as Arc<dyn LanguageModel>
211 })
212 .collect()
213 }
214
215 fn is_authenticated(&self, cx: &App) -> bool {
216 self.state.read(cx).is_authenticated()
217 }
218
219 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
220 self.state.update(cx, |state, cx| state.authenticate(cx))
221 }
222
223 fn configuration_view(
224 &self,
225 _target_agent: language_model::ConfigurationViewTargetAgent,
226 window: &mut Window,
227 cx: &mut App,
228 ) -> AnyView {
229 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
230 .into()
231 }
232
233 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
234 self.state
235 .update(cx, |state, cx| state.set_api_key(None, cx))
236 }
237}
238
/// A single Mistral model exposed through the [`LanguageModel`] trait.
pub struct MistralLanguageModel {
    id: LanguageModelId,
    model: mistral::Model,
    // Shared auth state; consulted for the API key on each request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Bounds concurrent in-flight requests (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
246
impl MistralLanguageModel {
    /// Starts a streaming completion request against the Mistral API.
    ///
    /// The API key and URL are resolved from shared state up front; the HTTP
    /// request itself runs behind the rate limiter. Fails with
    /// `LanguageModelCompletionError::NoApiKey` when no key is configured.
    fn stream_completion(
        &self,
        request: mistral::Request,
        affinity: Option<String>,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
    > {
        let http_client = self.http_client.clone();

        // Snapshot key + URL now so the async block doesn't touch app state.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = MistralLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request = mistral::stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                affinity,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
284
impl LanguageModel for MistralLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    // Partial tool-call deltas are always surfaced while streaming; see
    // MistralEventMapper::map_event.
    fn supports_streaming_tools(&self) -> bool {
        true
    }

    // Tool-choice support mirrors general tool support for this model.
    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn telemetry_id(&self) -> String {
        format!("mistral/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Estimates the token count of a request on a background thread.
    ///
    /// Counts with tiktoken's gpt-4 tokenizer as an approximation for Mistral
    /// models; only plain string contents are considered (images/tools are
    /// flattened via `string_contents`).
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        cx.background_spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.string_contents()),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    /// Streams a completion: converts the generic request to a Mistral one,
    /// issues it via the inherent `stream_completion`, and maps raw stream
    /// chunks into `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let (request, affinity) =
            into_mistral(request, self.model.clone(), self.max_output_tokens());
        let stream = self.stream_completion(request, affinity, cx);

        async move {
            let stream = stream.await?;
            let mapper = MistralEventMapper::new();
            Ok(mapper.map_stream(stream).boxed())
        }
        .boxed()
    }
}
379
/// Converts a generic [`LanguageModelRequest`] into a Mistral chat request.
///
/// Returns the request together with the request's `thread_id`, which the
/// caller passes as the "affinity" argument to `mistral::stream_completion`
/// (presumably for sticky request routing — confirm there).
///
/// Content the target model cannot accept (images on non-image models,
/// thinking parts on non-thinking models) is silently dropped.
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> (mistral::Request, Option<String>) {
    // Streaming is always enabled for this provider.
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                // Text/image/thinking parts accumulate into a single user
                // message. Tool results, by contrast, are pushed immediately
                // as separate `Tool` messages — so a tool result always
                // precedes the accumulated user content of the same message.
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                // Only emit the user message if any content actually survived.
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) if text.is_empty() => {
                            // Mistral API returns a 400 if there's neither content nor tool_calls
                        }
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            // Attach to the preceding assistant message when
                            // possible; otherwise start a content-less one.
                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    (
        mistral::Request {
            model: model.id().to_string(),
            messages,
            stream,
            stream_options: if stream {
                Some(mistral::StreamOptions {
                    stream_tool_calls: Some(true),
                })
            } else {
                None
            },
            max_tokens: max_output_tokens,
            temperature: request.temperature,
            response_format: None,
            // Default to Auto whenever tools are present and no explicit
            // (or applicable) choice was given.
            tool_choice: match request.tool_choice {
                Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                    Some(mistral::ToolChoice::Auto)
                }
                Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                    Some(mistral::ToolChoice::Any)
                }
                Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
                _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
                _ => None,
            },
            // Parallel tool calls are explicitly disabled when tools are used.
            parallel_tool_calls: if !request.tools.is_empty() {
                Some(false)
            } else {
                None
            },
            tools: request
                .tools
                .into_iter()
                .map(|tool| mistral::ToolDefinition::Function {
                    function: mistral::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        },
        request.thread_id,
    )
}
579
/// Maps raw Mistral stream chunks to [`LanguageModelCompletionEvent`]s,
/// accumulating fragmented tool-call deltas across chunks.
pub struct MistralEventMapper {
    // Tool-call fragments keyed by the delta's `index` field.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
583
impl MistralEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    /// Adapts a raw Mistral SSE stream into a stream of completion events.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

    /// Converts one stream chunk into zero or more completion events.
    ///
    /// NOTE(review): a chunk with an empty `choices` array is treated as an
    /// error here, so if the API ever emits usage-only chunks without choices
    /// they would be rejected — confirm against the Mistral streaming API.
    pub fn map_event(
        &mut self,
        event: mistral::StreamResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let Some(choice) = event.choices.first() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        // Text / thinking deltas.
        if let Some(content) = choice.delta.content.as_ref() {
            match content {
                mistral::MessageContentDelta::Text(text) => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                }
                mistral::MessageContentDelta::Parts(parts) => {
                    for part in parts {
                        match part {
                            mistral::MessagePart::Text { text } => {
                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                            }
                            mistral::MessagePart::Thinking { thinking } => {
                                for tp in thinking.iter().cloned() {
                                    match tp {
                                        mistral::ThinkingPart::Text { text } => {
                                            events.push(Ok(
                                                LanguageModelCompletionEvent::Thinking {
                                                    text,
                                                    signature: None,
                                                },
                                            ));
                                        }
                                    }
                                }
                            }
                            mistral::MessagePart::ImageUrl { .. } => {
                                // We currently don't emit a separate event for images in responses.
                            }
                        }
                    }
                }
            }
        }

        // Tool-call deltas: id/name/arguments can arrive split across chunks,
        // so accumulate them by index.
        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone()
                    && !tool_id.is_empty()
                {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone()
                        && !name.is_empty()
                    {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }

                // Emit a partial ToolUse event whenever the (repaired)
                // accumulated arguments parse as JSON; marked incomplete so
                // consumers can show streaming progress.
                if !entry.id.is_empty() && !entry.name.is_empty() {
                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
                        &fix_streamed_json(&entry.arguments),
                    ) {
                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: entry.id.clone().into(),
                                name: entry.name.as_str().into(),
                                is_input_complete: false,
                                input,
                                raw_input: entry.arguments.clone(),
                                thought_signature: None,
                            },
                        )));
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                // Mistral reports no cache metrics here.
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        if let Some(finish_reason) = choice.finish_reason.as_deref() {
            match finish_reason {
                "stop" => {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
                "tool_calls" => {
                    // Flush the fully accumulated tool calls before stopping.
                    events.extend(self.process_tool_calls());
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                }
                unexpected => {
                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }
        }

        events
    }

    /// Emits final (complete) ToolUse events for all accumulated tool calls,
    /// draining the accumulation state. Argument-JSON parse failures become
    /// `ToolUseJsonParseError` events rather than stream errors.
    fn process_tool_calls(
        &mut self,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut results = Vec::new();

        for (_, tool_call) in self.tool_calls_by_index.drain() {
            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                results.push(Err(LanguageModelCompletionError::from(anyhow!(
                    "Received incomplete tool call: missing id or name"
                ))));
                continue;
            }

            match parse_tool_arguments(&tool_call.arguments) {
                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
                    LanguageModelToolUse {
                        id: tool_call.id.into(),
                        name: tool_call.name.into(),
                        is_input_complete: true,
                        input,
                        raw_input: tool_call.arguments,
                        thought_signature: None,
                    },
                ))),
                Err(error) => {
                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                        id: tool_call.id.into(),
                        tool_name: tool_call.name.into(),
                        raw_input: tool_call.arguments.into(),
                        json_parse_error: error.to_string(),
                    }))
                }
            }
        }

        results
    }
}
756
/// A streamed tool call being assembled; id/name/argument fragments arrive
/// incrementally across chunks.
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}
763
/// Settings UI for entering or resetting the Mistral API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // `Some` while stored credentials are being loaded; cleared when done.
    load_credentials_task: Option<Task<()>>,
}
769
770impl ConfigurationView {
771 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
772 let api_key_editor =
773 cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
774
775 cx.observe(&state, |_, _, cx| {
776 cx.notify();
777 })
778 .detach();
779
780 let load_credentials_task = Some(cx.spawn_in(window, {
781 let state = state.clone();
782 async move |this, cx| {
783 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
784 // We don't log an error, because "not signed in" is also an error.
785 let _ = task.await;
786 }
787
788 this.update(cx, |this, cx| {
789 this.load_credentials_task = None;
790 cx.notify();
791 })
792 .log_err();
793 }
794 }));
795
796 Self {
797 api_key_editor,
798 state,
799 load_credentials_task,
800 }
801 }
802
803 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
804 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
805 if api_key.is_empty() {
806 return;
807 }
808
809 // url changes can cause the editor to be displayed again
810 self.api_key_editor
811 .update(cx, |editor, cx| editor.set_text("", window, cx));
812
813 let state = self.state.clone();
814 cx.spawn_in(window, async move |_, cx| {
815 state
816 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
817 .await
818 })
819 .detach_and_log_err(cx);
820 }
821
822 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
823 self.api_key_editor
824 .update(cx, |editor, cx| editor.set_text("", window, cx));
825
826 let state = self.state.clone();
827 cx.spawn_in(window, async move |_, cx| {
828 state
829 .update(cx, |state, cx| state.set_api_key(None, cx))
830 .await
831 })
832 .detach_and_log_err(cx);
833 }
834
835 fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
836 !self.state.read(cx).is_authenticated()
837 }
838}
839
impl Render for ConfigurationView {
    // Renders one of three states: loading credentials, API-key entry form,
    // or the "configured" card (with reset, unless the key came from the env).
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            // Mention the URL only when a non-default endpoint is configured.
            let api_url = MistralLanguageModelProvider::api_url(cx);
            if api_url == MISTRAL_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_api_key_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create one by visiting"))
                                .child(ButtonLink::new("Mistral's console", "https://console.mistral.ai/api-keys"))
                        )
                        .child(
                            ListBulletItem::new("Ensure your Mistral account has credits")
                        )
                        .child(
                            ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
                        ),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            v_flex()
                .size_full()
                .gap_1()
                .child(
                    ConfiguredApiCard::new(configured_card_label)
                        // An env-var-sourced key can't be reset from the UI.
                        .disabled(env_var_set)
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                        .when(env_var_set, |this| {
                            this.tooltip_label(format!(
                                "To reset your API key, \
                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
                            ))
                        }),
                )
                .into_any()
        }
    }
}
902
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    // Verifies role mapping, empty-assistant-message skipping, streaming flag,
    // and thread-id passthrough as the affinity value.
    #[test]
    fn test_into_mistral_basic_conversion() {
        let request = LanguageModelRequest {
            messages: vec![
                LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text("System prompt".into())],
                    cache: false,
                    reasoning_details: None,
                },
                LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![MessageContent::Text("Hello".into())],
                    cache: false,
                    reasoning_details: None,
                },
                // should skip empty assistant messages
                LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::Text("".into())],
                    cache: false,
                    reasoning_details: None,
                },
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: Some("abcdef".into()),
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: Default::default(),
        };

        let (mistral_request, affinity) =
            into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        // Only the system + user messages survive (empty assistant dropped).
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
        assert_eq!(affinity, Some("abcdef".into()));
    }

    // Verifies that text + image content becomes a single multipart user
    // message with a base64 data-URL image part on an image-capable model.
    #[test]
    fn test_into_mistral_with_image() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: None,
                    }),
                ],
                cache: false,
                reasoning_details: None,
            }],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let (mistral_request, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        assert!(matches!(
            &mistral_request.messages[0],
            mistral::RequestMessage::User {
                content: mistral::MessageContent::Multipart { .. }
            }
        ));

        if let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &mistral_request.messages[0]
        {
            assert_eq!(content.len(), 2);
            assert!(matches!(
                &content[0],
                mistral::MessagePart::Text { text } if text == "What's in this image?"
            ));
            assert!(matches!(
                &content[1],
                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
            ));
        }
    }
}