1use anyhow::{Result, anyhow};
2use collections::BTreeMap;
3use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
4use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
5use http_client::HttpClient;
6use language_model::{
7 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
8 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
9 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
10 LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
11 RateLimiter, Role, StopReason, TokenUsage,
12};
13use mistral::{MISTRAL_API_URL, StreamResponse};
14pub use settings::MistralAvailableModel as AvailableModel;
15use settings::{Settings, SettingsStore};
16use std::collections::HashMap;
17use std::pin::Pin;
18use std::str::FromStr;
19use std::sync::{Arc, LazyLock};
20use strum::IntoEnumIterator;
21use ui::{Icon, IconName, List, Tooltip, prelude::*};
22use ui_input::SingleLineInput;
23use util::{ResultExt, truncate_and_trailoff};
24use zed_env_vars::{EnvVar, env_var};
25
26use crate::{api_key::ApiKeyState, ui::InstructionListItem};
27
/// Provider identifier returned by the provider's `id()` and each model's `provider_id()`.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Display name of the provider; also reported in `NoApiKey` completion errors.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Name of the environment variable that may supply the Mistral API key.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
/// Lazily-built handle for [`API_KEY_ENV_VAR_NAME`], handed to the API key state when
/// loading or re-loading credentials.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
33
/// User-configurable settings for the Mistral provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// Base URL of the Mistral API. An empty string means `mistral::MISTRAL_API_URL` is used.
    pub api_url: String,
    /// Extra models declared in settings. An entry whose `name` equals a built-in model's id
    /// replaces that built-in model in `provided_models`.
    pub available_models: Vec<AvailableModel>,
}
39
/// Language model provider backed by the Mistral API.
pub struct MistralLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared provider state (API key), also observed by the configuration view.
    state: Entity<State>,
}
44
/// Observable provider state shared by the provider, its models, and the configuration view.
pub struct State {
    /// API key storage; keys are stored and looked up per configured API URL.
    api_key_state: ApiKeyState,
}
48
49impl State {
50 fn is_authenticated(&self) -> bool {
51 self.api_key_state.has_key()
52 }
53
54 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
55 let api_url = MistralLanguageModelProvider::api_url(cx);
56 self.api_key_state
57 .store(api_url, api_key, |this| &mut this.api_key_state, cx)
58 }
59
60 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
61 let api_url = MistralLanguageModelProvider::api_url(cx);
62 self.api_key_state.load_if_needed(
63 api_url,
64 &API_KEY_ENV_VAR,
65 |this| &mut this.api_key_state,
66 cx,
67 )
68 }
69}
70
71impl MistralLanguageModelProvider {
72 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
73 let state = cx.new(|cx| {
74 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
75 let api_url = Self::api_url(cx);
76 this.api_key_state.handle_url_change(
77 api_url,
78 &API_KEY_ENV_VAR,
79 |this| &mut this.api_key_state,
80 cx,
81 );
82 cx.notify();
83 })
84 .detach();
85 State {
86 api_key_state: ApiKeyState::new(Self::api_url(cx)),
87 }
88 });
89
90 Self { http_client, state }
91 }
92
93 fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
94 Arc::new(MistralLanguageModel {
95 id: LanguageModelId::from(model.id().to_string()),
96 model,
97 state: self.state.clone(),
98 http_client: self.http_client.clone(),
99 request_limiter: RateLimiter::new(4),
100 })
101 }
102
103 fn settings(cx: &App) -> &MistralSettings {
104 &crate::AllLanguageModelSettings::get_global(cx).mistral
105 }
106
107 fn api_url(cx: &App) -> SharedString {
108 let api_url = &Self::settings(cx).api_url;
109 if api_url.is_empty() {
110 mistral::MISTRAL_API_URL.into()
111 } else {
112 SharedString::new(api_url.as_str())
113 }
114 }
115}
116
117impl LanguageModelProviderState for MistralLanguageModelProvider {
118 type ObservableEntity = State;
119
120 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
121 Some(self.state.clone())
122 }
123}
124
125impl LanguageModelProvider for MistralLanguageModelProvider {
126 fn id(&self) -> LanguageModelProviderId {
127 PROVIDER_ID
128 }
129
130 fn name(&self) -> LanguageModelProviderName {
131 PROVIDER_NAME
132 }
133
134 fn icon(&self) -> IconName {
135 IconName::AiMistral
136 }
137
138 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
139 Some(self.create_language_model(mistral::Model::default()))
140 }
141
142 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
143 Some(self.create_language_model(mistral::Model::default_fast()))
144 }
145
146 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
147 let mut models = BTreeMap::default();
148
149 // Add base models from mistral::Model::iter()
150 for model in mistral::Model::iter() {
151 if !matches!(model, mistral::Model::Custom { .. }) {
152 models.insert(model.id().to_string(), model);
153 }
154 }
155
156 // Override with available models from settings
157 for model in &Self::settings(cx).available_models {
158 models.insert(
159 model.name.clone(),
160 mistral::Model::Custom {
161 name: model.name.clone(),
162 display_name: model.display_name.clone(),
163 max_tokens: model.max_tokens,
164 max_output_tokens: model.max_output_tokens,
165 max_completion_tokens: model.max_completion_tokens,
166 supports_tools: model.supports_tools,
167 supports_images: model.supports_images,
168 supports_thinking: model.supports_thinking,
169 },
170 );
171 }
172
173 models
174 .into_values()
175 .map(|model| {
176 Arc::new(MistralLanguageModel {
177 id: LanguageModelId::from(model.id().to_string()),
178 model,
179 state: self.state.clone(),
180 http_client: self.http_client.clone(),
181 request_limiter: RateLimiter::new(4),
182 }) as Arc<dyn LanguageModel>
183 })
184 .collect()
185 }
186
187 fn is_authenticated(&self, cx: &App) -> bool {
188 self.state.read(cx).is_authenticated()
189 }
190
191 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
192 self.state.update(cx, |state, cx| state.authenticate(cx))
193 }
194
195 fn configuration_view(
196 &self,
197 _target_agent: language_model::ConfigurationViewTargetAgent,
198 window: &mut Window,
199 cx: &mut App,
200 ) -> AnyView {
201 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
202 .into()
203 }
204
205 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
206 self.state
207 .update(cx, |state, cx| state.set_api_key(None, cx))
208 }
209}
210
/// A single Mistral model exposed through the [`LanguageModel`] trait.
pub struct MistralLanguageModel {
    /// Identifier built from the model's id string.
    id: LanguageModelId,
    /// Model metadata: id, display name, token limits, and capability flags.
    model: mistral::Model,
    /// Shared provider state; read at request time for the API key and URL.
    state: Entity<State>,
    /// Client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Throttles outgoing requests (presumably a concurrency cap; constructed with 4).
    request_limiter: RateLimiter,
}
218
219impl MistralLanguageModel {
220 fn stream_completion(
221 &self,
222 request: mistral::Request,
223 cx: &AsyncApp,
224 ) -> BoxFuture<
225 'static,
226 Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
227 > {
228 let http_client = self.http_client.clone();
229
230 let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
231 let api_url = MistralLanguageModelProvider::api_url(cx);
232 (state.api_key_state.key(&api_url), api_url)
233 }) else {
234 return future::ready(Err(anyhow!("App state dropped"))).boxed();
235 };
236
237 let future = self.request_limiter.stream(async move {
238 let Some(api_key) = api_key else {
239 return Err(LanguageModelCompletionError::NoApiKey {
240 provider: PROVIDER_NAME,
241 });
242 };
243 let request =
244 mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
245 let response = request.await?;
246 Ok(response)
247 });
248
249 async move { Ok(future.await?.boxed()) }.boxed()
250 }
251}
252
253impl LanguageModel for MistralLanguageModel {
254 fn id(&self) -> LanguageModelId {
255 self.id.clone()
256 }
257
258 fn name(&self) -> LanguageModelName {
259 LanguageModelName::from(self.model.display_name().to_string())
260 }
261
262 fn provider_id(&self) -> LanguageModelProviderId {
263 PROVIDER_ID
264 }
265
266 fn provider_name(&self) -> LanguageModelProviderName {
267 PROVIDER_NAME
268 }
269
270 fn supports_tools(&self) -> bool {
271 self.model.supports_tools()
272 }
273
274 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
275 self.model.supports_tools()
276 }
277
278 fn supports_images(&self) -> bool {
279 self.model.supports_images()
280 }
281
282 fn telemetry_id(&self) -> String {
283 format!("mistral/{}", self.model.id())
284 }
285
286 fn max_token_count(&self) -> u64 {
287 self.model.max_token_count()
288 }
289
290 fn max_output_tokens(&self) -> Option<u64> {
291 self.model.max_output_tokens()
292 }
293
294 fn count_tokens(
295 &self,
296 request: LanguageModelRequest,
297 cx: &App,
298 ) -> BoxFuture<'static, Result<u64>> {
299 cx.background_spawn(async move {
300 let messages = request
301 .messages
302 .into_iter()
303 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
304 role: match message.role {
305 Role::User => "user".into(),
306 Role::Assistant => "assistant".into(),
307 Role::System => "system".into(),
308 },
309 content: Some(message.string_contents()),
310 name: None,
311 function_call: None,
312 })
313 .collect::<Vec<_>>();
314
315 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
316 })
317 .boxed()
318 }
319
320 fn stream_completion(
321 &self,
322 request: LanguageModelRequest,
323 cx: &AsyncApp,
324 ) -> BoxFuture<
325 'static,
326 Result<
327 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
328 LanguageModelCompletionError,
329 >,
330 > {
331 let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
332 let stream = self.stream_completion(request, cx);
333
334 async move {
335 let stream = stream.await?;
336 let mapper = MistralEventMapper::new();
337 Ok(mapper.map_stream(stream).boxed())
338 }
339 .boxed()
340 }
341}
342
343pub fn into_mistral(
344 request: LanguageModelRequest,
345 model: mistral::Model,
346 max_output_tokens: Option<u64>,
347) -> mistral::Request {
348 let stream = true;
349
350 let mut messages = Vec::new();
351 for message in &request.messages {
352 match message.role {
353 Role::User => {
354 let mut message_content = mistral::MessageContent::empty();
355 for content in &message.content {
356 match content {
357 MessageContent::Text(text) => {
358 message_content
359 .push_part(mistral::MessagePart::Text { text: text.clone() });
360 }
361 MessageContent::Image(image_content) => {
362 if model.supports_images() {
363 message_content.push_part(mistral::MessagePart::ImageUrl {
364 image_url: image_content.to_base64_url(),
365 });
366 }
367 }
368 MessageContent::Thinking { text, .. } => {
369 if model.supports_thinking() {
370 message_content.push_part(mistral::MessagePart::Thinking {
371 thinking: vec![mistral::ThinkingPart::Text {
372 text: text.clone(),
373 }],
374 });
375 }
376 }
377 MessageContent::RedactedThinking(_) => {}
378 MessageContent::ToolUse(_) => {
379 // Tool use is not supported in User messages for Mistral
380 }
381 MessageContent::ToolResult(tool_result) => {
382 let tool_content = match &tool_result.content {
383 LanguageModelToolResultContent::Text(text) => text.to_string(),
384 LanguageModelToolResultContent::Image(_) => {
385 "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
386 }
387 };
388 messages.push(mistral::RequestMessage::Tool {
389 content: tool_content,
390 tool_call_id: tool_result.tool_use_id.to_string(),
391 });
392 }
393 }
394 }
395 if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
396 {
397 messages.push(mistral::RequestMessage::User {
398 content: message_content,
399 });
400 }
401 }
402 Role::Assistant => {
403 for content in &message.content {
404 match content {
405 MessageContent::Text(text) => {
406 messages.push(mistral::RequestMessage::Assistant {
407 content: Some(mistral::MessageContent::Plain {
408 content: text.clone(),
409 }),
410 tool_calls: Vec::new(),
411 });
412 }
413 MessageContent::Thinking { text, .. } => {
414 if model.supports_thinking() {
415 messages.push(mistral::RequestMessage::Assistant {
416 content: Some(mistral::MessageContent::Multipart {
417 content: vec![mistral::MessagePart::Thinking {
418 thinking: vec![mistral::ThinkingPart::Text {
419 text: text.clone(),
420 }],
421 }],
422 }),
423 tool_calls: Vec::new(),
424 });
425 }
426 }
427 MessageContent::RedactedThinking(_) => {}
428 MessageContent::Image(_) => {}
429 MessageContent::ToolUse(tool_use) => {
430 let tool_call = mistral::ToolCall {
431 id: tool_use.id.to_string(),
432 content: mistral::ToolCallContent::Function {
433 function: mistral::FunctionContent {
434 name: tool_use.name.to_string(),
435 arguments: serde_json::to_string(&tool_use.input)
436 .unwrap_or_default(),
437 },
438 },
439 };
440
441 if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
442 messages.last_mut()
443 {
444 tool_calls.push(tool_call);
445 } else {
446 messages.push(mistral::RequestMessage::Assistant {
447 content: None,
448 tool_calls: vec![tool_call],
449 });
450 }
451 }
452 MessageContent::ToolResult(_) => {
453 // Tool results are not supported in Assistant messages
454 }
455 }
456 }
457 }
458 Role::System => {
459 for content in &message.content {
460 match content {
461 MessageContent::Text(text) => {
462 messages.push(mistral::RequestMessage::System {
463 content: mistral::MessageContent::Plain {
464 content: text.clone(),
465 },
466 });
467 }
468 MessageContent::Thinking { text, .. } => {
469 if model.supports_thinking() {
470 messages.push(mistral::RequestMessage::System {
471 content: mistral::MessageContent::Multipart {
472 content: vec![mistral::MessagePart::Thinking {
473 thinking: vec![mistral::ThinkingPart::Text {
474 text: text.clone(),
475 }],
476 }],
477 },
478 });
479 }
480 }
481 MessageContent::RedactedThinking(_) => {}
482 MessageContent::Image(_)
483 | MessageContent::ToolUse(_)
484 | MessageContent::ToolResult(_) => {
485 // Images and tools are not supported in System messages
486 }
487 }
488 }
489 }
490 }
491 }
492
493 mistral::Request {
494 model: model.id().to_string(),
495 messages,
496 stream,
497 max_tokens: max_output_tokens,
498 temperature: request.temperature,
499 response_format: None,
500 tool_choice: match request.tool_choice {
501 Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
502 Some(mistral::ToolChoice::Auto)
503 }
504 Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
505 Some(mistral::ToolChoice::Any)
506 }
507 Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
508 _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
509 _ => None,
510 },
511 parallel_tool_calls: if !request.tools.is_empty() {
512 Some(false)
513 } else {
514 None
515 },
516 tools: request
517 .tools
518 .into_iter()
519 .map(|tool| mistral::ToolDefinition::Function {
520 function: mistral::FunctionDefinition {
521 name: tool.name,
522 description: Some(tool.description),
523 parameters: Some(tool.input_schema),
524 },
525 })
526 .collect(),
527 }
528}
529
530pub struct MistralEventMapper {
531 tool_calls_by_index: HashMap<usize, RawToolCall>,
532}
533
534impl MistralEventMapper {
535 pub fn new() -> Self {
536 Self {
537 tool_calls_by_index: HashMap::default(),
538 }
539 }
540
541 pub fn map_stream(
542 mut self,
543 events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
544 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
545 {
546 events.flat_map(move |event| {
547 futures::stream::iter(match event {
548 Ok(event) => self.map_event(event),
549 Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
550 })
551 })
552 }
553
554 pub fn map_event(
555 &mut self,
556 event: mistral::StreamResponse,
557 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
558 let Some(choice) = event.choices.first() else {
559 return vec![Err(LanguageModelCompletionError::from(anyhow!(
560 "Response contained no choices"
561 )))];
562 };
563
564 let mut events = Vec::new();
565 if let Some(content) = choice.delta.content.as_ref() {
566 match content {
567 mistral::MessageContentDelta::Text(text) => {
568 events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
569 }
570 mistral::MessageContentDelta::Parts(parts) => {
571 for part in parts {
572 match part {
573 mistral::MessagePart::Text { text } => {
574 events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
575 }
576 mistral::MessagePart::Thinking { thinking } => {
577 for tp in thinking.iter().cloned() {
578 match tp {
579 mistral::ThinkingPart::Text { text } => {
580 events.push(Ok(
581 LanguageModelCompletionEvent::Thinking {
582 text,
583 signature: None,
584 },
585 ));
586 }
587 }
588 }
589 }
590 mistral::MessagePart::ImageUrl { .. } => {
591 // We currently don't emit a separate event for images in responses.
592 }
593 }
594 }
595 }
596 }
597 }
598
599 if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
600 for tool_call in tool_calls {
601 let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
602
603 if let Some(tool_id) = tool_call.id.clone() {
604 entry.id = tool_id;
605 }
606
607 if let Some(function) = tool_call.function.as_ref() {
608 if let Some(name) = function.name.clone() {
609 entry.name = name;
610 }
611
612 if let Some(arguments) = function.arguments.clone() {
613 entry.arguments.push_str(&arguments);
614 }
615 }
616 }
617 }
618
619 if let Some(usage) = event.usage {
620 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
621 input_tokens: usage.prompt_tokens,
622 output_tokens: usage.completion_tokens,
623 cache_creation_input_tokens: 0,
624 cache_read_input_tokens: 0,
625 })));
626 }
627
628 if let Some(finish_reason) = choice.finish_reason.as_deref() {
629 match finish_reason {
630 "stop" => {
631 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
632 }
633 "tool_calls" => {
634 events.extend(self.process_tool_calls());
635 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
636 }
637 unexpected => {
638 log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
639 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
640 }
641 }
642 }
643
644 events
645 }
646
647 fn process_tool_calls(
648 &mut self,
649 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
650 let mut results = Vec::new();
651
652 for (_, tool_call) in self.tool_calls_by_index.drain() {
653 if tool_call.id.is_empty() || tool_call.name.is_empty() {
654 results.push(Err(LanguageModelCompletionError::from(anyhow!(
655 "Received incomplete tool call: missing id or name"
656 ))));
657 continue;
658 }
659
660 match serde_json::Value::from_str(&tool_call.arguments) {
661 Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
662 LanguageModelToolUse {
663 id: tool_call.id.into(),
664 name: tool_call.name.into(),
665 is_input_complete: true,
666 input,
667 raw_input: tool_call.arguments,
668 },
669 ))),
670 Err(error) => {
671 results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
672 id: tool_call.id.into(),
673 tool_name: tool_call.name.into(),
674 raw_input: tool_call.arguments.into(),
675 json_parse_error: error.to_string(),
676 }))
677 }
678 }
679 }
680
681 results
682 }
683}
684
/// A tool call being assembled from streamed deltas.
#[derive(Default)]
struct RawToolCall {
    /// Tool call id; empty until a delta supplies one.
    id: String,
    /// Function name; empty until a delta supplies one.
    name: String,
    /// Concatenation of every argument fragment received so far (parsed as JSON when the
    /// call is completed).
    arguments: String,
}
691
/// Settings UI for entering, inspecting, or resetting the Mistral API key.
struct ConfigurationView {
    /// Single-line input the user types or pastes the API key into.
    api_key_editor: Entity<SingleLineInput>,
    /// Shared provider state holding the API key.
    state: Entity<State>,
    /// Credential load started when the view is created; reset to `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
}
697
698impl ConfigurationView {
699 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
700 let api_key_editor =
701 cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2"));
702
703 cx.observe(&state, |_, _, cx| {
704 cx.notify();
705 })
706 .detach();
707
708 let load_credentials_task = Some(cx.spawn_in(window, {
709 let state = state.clone();
710 async move |this, cx| {
711 if let Some(task) = state
712 .update(cx, |state, cx| state.authenticate(cx))
713 .log_err()
714 {
715 // We don't log an error, because "not signed in" is also an error.
716 let _ = task.await;
717 }
718
719 this.update(cx, |this, cx| {
720 this.load_credentials_task = None;
721 cx.notify();
722 })
723 .log_err();
724 }
725 }));
726
727 Self {
728 api_key_editor,
729 state,
730 load_credentials_task,
731 }
732 }
733
734 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
735 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
736 if api_key.is_empty() {
737 return;
738 }
739
740 // url changes can cause the editor to be displayed again
741 self.api_key_editor
742 .update(cx, |editor, cx| editor.set_text("", window, cx));
743
744 let state = self.state.clone();
745 cx.spawn_in(window, async move |_, cx| {
746 state
747 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
748 .await
749 })
750 .detach_and_log_err(cx);
751 }
752
753 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
754 self.api_key_editor
755 .update(cx, |editor, cx| editor.set_text("", window, cx));
756
757 let state = self.state.clone();
758 cx.spawn_in(window, async move |_, cx| {
759 state
760 .update(cx, |state, cx| state.set_api_key(None, cx))?
761 .await
762 })
763 .detach_and_log_err(cx);
764 }
765
766 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
767 !self.state.read(cx).is_authenticated()
768 }
769}
770
771impl Render for ConfigurationView {
772 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
773 let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
774
775 if self.load_credentials_task.is_some() {
776 div().child(Label::new("Loading credentials...")).into_any()
777 } else if self.should_render_editor(cx) {
778 v_flex()
779 .size_full()
780 .on_action(cx.listener(Self::save_api_key))
781 .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
782 .child(
783 List::new()
784 .child(InstructionListItem::new(
785 "Create one by visiting",
786 Some("Mistral's console"),
787 Some("https://console.mistral.ai/api-keys"),
788 ))
789 .child(InstructionListItem::text_only(
790 "Ensure your Mistral account has credits",
791 ))
792 .child(InstructionListItem::text_only(
793 "Paste your API key below and hit enter to start using the assistant",
794 )),
795 )
796 .child(self.api_key_editor.clone())
797 .child(
798 Label::new(
799 format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
800 )
801 .size(LabelSize::Small).color(Color::Muted),
802 )
803 .into_any()
804 } else {
805 h_flex()
806 .mt_1()
807 .p_1()
808 .justify_between()
809 .rounded_md()
810 .border_1()
811 .border_color(cx.theme().colors().border)
812 .bg(cx.theme().colors().background)
813 .child(
814 h_flex()
815 .gap_1()
816 .child(Icon::new(IconName::Check).color(Color::Success))
817 .child(Label::new(if env_var_set {
818 format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
819 } else {
820 let api_url = MistralLanguageModelProvider::api_url(cx);
821 if api_url == MISTRAL_API_URL {
822 "API key configured".to_string()
823 } else {
824 format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
825 }
826 })),
827 )
828 .child(
829 Button::new("reset-key", "Reset Key")
830 .label_size(LabelSize::Small)
831 .icon(Some(IconName::Trash))
832 .icon_size(IconSize::Small)
833 .icon_position(IconPosition::Start)
834 .disabled(env_var_set)
835 .when(env_var_set, |this| {
836 this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
837 })
838 .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
839 )
840 .into_any()
841 }
842 }
843}
844
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// Builds a request with the given messages and otherwise neutral settings
    /// (no tools, no temperature, thinking allowed).
    fn base_request(messages: Vec<LanguageModelRequestMessage>) -> LanguageModelRequest {
        LanguageModelRequest {
            messages,
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        }
    }

    /// Builds an uncached message with the given role and content.
    fn message(role: Role, content: Vec<MessageContent>) -> LanguageModelRequestMessage {
        LanguageModelRequestMessage {
            role,
            content,
            cache: false,
        }
    }

    #[test]
    fn test_into_mistral_basic_conversion() {
        let mut request = base_request(vec![
            message(Role::System, vec![MessageContent::Text("System prompt".into())]),
            message(Role::User, vec![MessageContent::Text("Hello".into())]),
        ]);
        request.temperature = Some(0.5);

        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
    }

    #[test]
    fn test_into_mistral_with_image() {
        let request = base_request(vec![message(
            Role::User,
            vec![
                MessageContent::Text("What's in this image?".into()),
                MessageContent::Image(LanguageModelImage {
                    source: "base64data".into(),
                    size: Default::default(),
                }),
            ],
        )]);

        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &mistral_request.messages[0]
        else {
            panic!("expected a single multipart user message");
        };

        assert_eq!(content.len(), 2);
        assert!(matches!(
            &content[0],
            mistral::MessagePart::Text { text } if text == "What's in this image?"
        ));
        assert!(matches!(
            &content[1],
            mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
        ));
    }
}