1use anyhow::{Result, anyhow};
2use collections::BTreeMap;
3use credentials_provider::CredentialsProvider;
4
5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
6use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
7use http_client::HttpClient;
8use language_model::{
9 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
10 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
13 LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
14};
15pub use mistral::{MISTRAL_API_URL, StreamResponse};
16pub use settings::MistralAvailableModel as AvailableModel;
17use settings::{Settings, SettingsStore};
18use std::collections::HashMap;
19use std::pin::Pin;
20use std::sync::{Arc, LazyLock};
21use strum::IntoEnumIterator;
22use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
23use ui_input::InputField;
24use util::ResultExt;
25
26use language_model::util::{fix_streamed_json, parse_tool_arguments};
27
/// Identifier under which this provider is registered.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Human-readable provider name; also reported in `NoApiKey` errors.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Environment variable that can supply the Mistral API key.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
/// Lazily-built [`EnvVar`] handle for [`API_KEY_ENV_VAR_NAME`]; handed to
/// `ApiKeyState::new` when the provider state is created.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

/// User-configurable settings for the Mistral provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// Base URL for API requests; an empty string means `mistral::MISTRAL_API_URL`.
    pub api_url: String,
    /// Extra models exposed as `mistral::Model::Custom`. An entry whose `name`
    /// equals a built-in model id replaces that built-in model.
    pub available_models: Vec<AvailableModel>,
}

/// Language-model provider backed by the Mistral API.
pub struct MistralLanguageModelProvider {
    /// HTTP client shared with every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Credential state shared with the models and the configuration UI.
    pub state: Entity<State>,
}

/// Credential state shared by the provider, its models and its configuration view.
pub struct State {
    /// The current API key and its source (see `is_from_env_var`).
    api_key_state: ApiKeyState,
    /// Passed to `ApiKeyState` whenever the key is stored or loaded.
    credentials_provider: Arc<dyn CredentialsProvider>,
}
49
50impl State {
51 fn is_authenticated(&self) -> bool {
52 self.api_key_state.has_key()
53 }
54
55 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
56 let credentials_provider = self.credentials_provider.clone();
57 let api_url = MistralLanguageModelProvider::api_url(cx);
58 self.api_key_state.store(
59 api_url,
60 api_key,
61 |this| &mut this.api_key_state,
62 credentials_provider,
63 cx,
64 )
65 }
66
67 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
68 let credentials_provider = self.credentials_provider.clone();
69 let api_url = MistralLanguageModelProvider::api_url(cx);
70 self.api_key_state.load_if_needed(
71 api_url,
72 |this| &mut this.api_key_state,
73 credentials_provider,
74 cx,
75 )
76 }
77}
78
/// Newtype that keeps the provider singleton in gpui's global storage, so
/// [`MistralLanguageModelProvider::global`] hands out the same instance on
/// every call after the first.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
82
83impl MistralLanguageModelProvider {
84 pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
85 cx.try_global::<GlobalMistralLanguageModelProvider>()
86 .map(|this| &this.0)
87 }
88
89 pub fn global(
90 http_client: Arc<dyn HttpClient>,
91 credentials_provider: Arc<dyn CredentialsProvider>,
92 cx: &mut App,
93 ) -> Arc<Self> {
94 if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
95 return this.0.clone();
96 }
97 let state = cx.new(|cx| {
98 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
99 let credentials_provider = this.credentials_provider.clone();
100 let api_url = Self::api_url(cx);
101 this.api_key_state.handle_url_change(
102 api_url,
103 |this| &mut this.api_key_state,
104 credentials_provider,
105 cx,
106 );
107 cx.notify();
108 })
109 .detach();
110 State {
111 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
112 credentials_provider,
113 }
114 });
115
116 let this = Arc::new(Self { http_client, state });
117 cx.set_global(GlobalMistralLanguageModelProvider(this));
118 cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
119 }
120
121 fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
122 Arc::new(MistralLanguageModel {
123 id: LanguageModelId::from(model.id().to_string()),
124 model,
125 state: self.state.clone(),
126 http_client: self.http_client.clone(),
127 request_limiter: RateLimiter::new(4),
128 })
129 }
130
131 fn settings(cx: &App) -> &MistralSettings {
132 &crate::AllLanguageModelSettings::get_global(cx).mistral
133 }
134
135 pub fn api_url(cx: &App) -> SharedString {
136 let api_url = &Self::settings(cx).api_url;
137 if api_url.is_empty() {
138 mistral::MISTRAL_API_URL.into()
139 } else {
140 SharedString::new(api_url.as_str())
141 }
142 }
143}
144
145impl LanguageModelProviderState for MistralLanguageModelProvider {
146 type ObservableEntity = State;
147
148 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
149 Some(self.state.clone())
150 }
151}
152
153impl LanguageModelProvider for MistralLanguageModelProvider {
154 fn id(&self) -> LanguageModelProviderId {
155 PROVIDER_ID
156 }
157
158 fn name(&self) -> LanguageModelProviderName {
159 PROVIDER_NAME
160 }
161
162 fn icon(&self) -> IconOrSvg {
163 IconOrSvg::Icon(IconName::AiMistral)
164 }
165
166 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
167 Some(self.create_language_model(mistral::Model::default()))
168 }
169
170 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
171 Some(self.create_language_model(mistral::Model::default_fast()))
172 }
173
174 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
175 let mut models = BTreeMap::default();
176
177 // Add base models from mistral::Model::iter()
178 for model in mistral::Model::iter() {
179 if !matches!(model, mistral::Model::Custom { .. }) {
180 models.insert(model.id().to_string(), model);
181 }
182 }
183
184 // Override with available models from settings
185 for model in &Self::settings(cx).available_models {
186 models.insert(
187 model.name.clone(),
188 mistral::Model::Custom {
189 name: model.name.clone(),
190 display_name: model.display_name.clone(),
191 max_tokens: model.max_tokens,
192 max_output_tokens: model.max_output_tokens,
193 max_completion_tokens: model.max_completion_tokens,
194 supports_tools: model.supports_tools,
195 supports_images: model.supports_images,
196 supports_thinking: model.supports_thinking,
197 },
198 );
199 }
200
201 models
202 .into_values()
203 .map(|model| {
204 Arc::new(MistralLanguageModel {
205 id: LanguageModelId::from(model.id().to_string()),
206 model,
207 state: self.state.clone(),
208 http_client: self.http_client.clone(),
209 request_limiter: RateLimiter::new(4),
210 }) as Arc<dyn LanguageModel>
211 })
212 .collect()
213 }
214
215 fn is_authenticated(&self, cx: &App) -> bool {
216 self.state.read(cx).is_authenticated()
217 }
218
219 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
220 self.state.update(cx, |state, cx| state.authenticate(cx))
221 }
222
223 fn configuration_view(
224 &self,
225 _target_agent: language_model::ConfigurationViewTargetAgent,
226 window: &mut Window,
227 cx: &mut App,
228 ) -> AnyView {
229 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
230 .into()
231 }
232
233 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
234 self.state
235 .update(cx, |state, cx| state.set_api_key(None, cx))
236 }
237}
238
/// A single Mistral model exposed through the [`LanguageModel`] trait.
pub struct MistralLanguageModel {
    /// Zed-side identifier; the provider derives it from `model.id()`.
    id: LanguageModelId,
    /// Model description: id, display name, token limits and capabilities.
    model: mistral::Model,
    /// Provider state, read on every request to obtain the API key and URL.
    state: Entity<State>,
    /// Client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Wraps each outgoing request via `RateLimiter::stream`. The provider
    /// constructs it with `RateLimiter::new(4)`.
    request_limiter: RateLimiter,
}
246
247impl MistralLanguageModel {
248 fn stream_completion(
249 &self,
250 request: mistral::Request,
251 affinity: Option<String>,
252 cx: &AsyncApp,
253 ) -> BoxFuture<
254 'static,
255 Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
256 > {
257 let http_client = self.http_client.clone();
258
259 let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
260 let api_url = MistralLanguageModelProvider::api_url(cx);
261 (state.api_key_state.key(&api_url), api_url)
262 });
263
264 let future = self.request_limiter.stream(async move {
265 let Some(api_key) = api_key else {
266 return Err(LanguageModelCompletionError::NoApiKey {
267 provider: PROVIDER_NAME,
268 });
269 };
270 let request = mistral::stream_completion(
271 http_client.as_ref(),
272 &api_url,
273 &api_key,
274 request,
275 affinity,
276 );
277 let response = request.await?;
278 Ok(response)
279 });
280
281 async move { Ok(future.await?.boxed()) }.boxed()
282 }
283}
284
285impl LanguageModel for MistralLanguageModel {
286 fn id(&self) -> LanguageModelId {
287 self.id.clone()
288 }
289
290 fn name(&self) -> LanguageModelName {
291 LanguageModelName::from(self.model.display_name().to_string())
292 }
293
294 fn provider_id(&self) -> LanguageModelProviderId {
295 PROVIDER_ID
296 }
297
298 fn provider_name(&self) -> LanguageModelProviderName {
299 PROVIDER_NAME
300 }
301
302 fn supports_tools(&self) -> bool {
303 self.model.supports_tools()
304 }
305
306 fn supports_streaming_tools(&self) -> bool {
307 true
308 }
309
310 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
311 self.model.supports_tools()
312 }
313
314 fn supports_images(&self) -> bool {
315 self.model.supports_images()
316 }
317
318 fn telemetry_id(&self) -> String {
319 format!("mistral/{}", self.model.id())
320 }
321
322 fn max_token_count(&self) -> u64 {
323 self.model.max_token_count()
324 }
325
326 fn max_output_tokens(&self) -> Option<u64> {
327 self.model.max_output_tokens()
328 }
329
330 fn stream_completion(
331 &self,
332 request: LanguageModelRequest,
333 cx: &AsyncApp,
334 ) -> BoxFuture<
335 'static,
336 Result<
337 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
338 LanguageModelCompletionError,
339 >,
340 > {
341 let (request, affinity) =
342 into_mistral(request, self.model.clone(), self.max_output_tokens());
343 let stream = self.stream_completion(request, affinity, cx);
344
345 async move {
346 let stream = stream.await?;
347 let mapper = MistralEventMapper::new();
348 Ok(mapper.map_stream(stream).boxed())
349 }
350 .boxed()
351 }
352}
353
354pub fn into_mistral(
355 request: LanguageModelRequest,
356 model: mistral::Model,
357 max_output_tokens: Option<u64>,
358) -> (mistral::Request, Option<String>) {
359 let stream = true;
360
361 let mut messages = Vec::new();
362 for message in &request.messages {
363 match message.role {
364 Role::User => {
365 let mut message_content = mistral::MessageContent::empty();
366 for content in &message.content {
367 match content {
368 MessageContent::Text(text) => {
369 message_content
370 .push_part(mistral::MessagePart::Text { text: text.clone() });
371 }
372 MessageContent::Image(image_content) => {
373 if model.supports_images() {
374 message_content.push_part(mistral::MessagePart::ImageUrl {
375 image_url: image_content.to_base64_url(),
376 });
377 }
378 }
379 MessageContent::Thinking { text, .. } => {
380 if model.supports_thinking() {
381 message_content.push_part(mistral::MessagePart::Thinking {
382 thinking: vec![mistral::ThinkingPart::Text {
383 text: text.clone(),
384 }],
385 });
386 }
387 }
388 MessageContent::RedactedThinking(_) => {}
389 MessageContent::ToolUse(_) => {
390 // Tool use is not supported in User messages for Mistral
391 }
392 MessageContent::ToolResult(tool_result) => {
393 let tool_content = match &tool_result.content {
394 LanguageModelToolResultContent::Text(text) => text.to_string(),
395 LanguageModelToolResultContent::Image(_) => {
396 "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
397 }
398 };
399 messages.push(mistral::RequestMessage::Tool {
400 content: tool_content,
401 tool_call_id: tool_result.tool_use_id.to_string(),
402 });
403 }
404 }
405 }
406 if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
407 {
408 messages.push(mistral::RequestMessage::User {
409 content: message_content,
410 });
411 }
412 }
413 Role::Assistant => {
414 for content in &message.content {
415 match content {
416 MessageContent::Text(text) if text.is_empty() => {
417 // Mistral API returns a 400 if there's neither content nor tool_calls
418 }
419 MessageContent::Text(text) => {
420 messages.push(mistral::RequestMessage::Assistant {
421 content: Some(mistral::MessageContent::Plain {
422 content: text.clone(),
423 }),
424 tool_calls: Vec::new(),
425 });
426 }
427 MessageContent::Thinking { text, .. } => {
428 if model.supports_thinking() {
429 messages.push(mistral::RequestMessage::Assistant {
430 content: Some(mistral::MessageContent::Multipart {
431 content: vec![mistral::MessagePart::Thinking {
432 thinking: vec![mistral::ThinkingPart::Text {
433 text: text.clone(),
434 }],
435 }],
436 }),
437 tool_calls: Vec::new(),
438 });
439 }
440 }
441 MessageContent::RedactedThinking(_) => {}
442 MessageContent::Image(_) => {}
443 MessageContent::ToolUse(tool_use) => {
444 let tool_call = mistral::ToolCall {
445 id: tool_use.id.to_string(),
446 content: mistral::ToolCallContent::Function {
447 function: mistral::FunctionContent {
448 name: tool_use.name.to_string(),
449 arguments: serde_json::to_string(&tool_use.input)
450 .unwrap_or_default(),
451 },
452 },
453 };
454
455 if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
456 messages.last_mut()
457 {
458 tool_calls.push(tool_call);
459 } else {
460 messages.push(mistral::RequestMessage::Assistant {
461 content: None,
462 tool_calls: vec![tool_call],
463 });
464 }
465 }
466 MessageContent::ToolResult(_) => {
467 // Tool results are not supported in Assistant messages
468 }
469 }
470 }
471 }
472 Role::System => {
473 for content in &message.content {
474 match content {
475 MessageContent::Text(text) => {
476 messages.push(mistral::RequestMessage::System {
477 content: mistral::MessageContent::Plain {
478 content: text.clone(),
479 },
480 });
481 }
482 MessageContent::Thinking { text, .. } => {
483 if model.supports_thinking() {
484 messages.push(mistral::RequestMessage::System {
485 content: mistral::MessageContent::Multipart {
486 content: vec![mistral::MessagePart::Thinking {
487 thinking: vec![mistral::ThinkingPart::Text {
488 text: text.clone(),
489 }],
490 }],
491 },
492 });
493 }
494 }
495 MessageContent::RedactedThinking(_) => {}
496 MessageContent::Image(_)
497 | MessageContent::ToolUse(_)
498 | MessageContent::ToolResult(_) => {
499 // Images and tools are not supported in System messages
500 }
501 }
502 }
503 }
504 }
505 }
506
507 (
508 mistral::Request {
509 model: model.id().to_string(),
510 messages,
511 stream,
512 stream_options: if stream {
513 Some(mistral::StreamOptions {
514 stream_tool_calls: Some(true),
515 })
516 } else {
517 None
518 },
519 max_tokens: max_output_tokens,
520 temperature: request.temperature,
521 response_format: None,
522 tool_choice: match request.tool_choice {
523 Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
524 Some(mistral::ToolChoice::Auto)
525 }
526 Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
527 Some(mistral::ToolChoice::Any)
528 }
529 Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
530 _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
531 _ => None,
532 },
533 parallel_tool_calls: if !request.tools.is_empty() {
534 Some(false)
535 } else {
536 None
537 },
538 tools: request
539 .tools
540 .into_iter()
541 .map(|tool| mistral::ToolDefinition::Function {
542 function: mistral::FunctionDefinition {
543 name: tool.name,
544 description: Some(tool.description),
545 parameters: Some(tool.input_schema),
546 },
547 })
548 .collect(),
549 },
550 request.thread_id,
551 )
552}
553
554pub struct MistralEventMapper {
555 tool_calls_by_index: HashMap<usize, RawToolCall>,
556}
557
558impl MistralEventMapper {
559 pub fn new() -> Self {
560 Self {
561 tool_calls_by_index: HashMap::default(),
562 }
563 }
564
565 pub fn map_stream(
566 mut self,
567 events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
568 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
569 {
570 events.flat_map(move |event| {
571 futures::stream::iter(match event {
572 Ok(event) => self.map_event(event),
573 Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
574 })
575 })
576 }
577
578 pub fn map_event(
579 &mut self,
580 event: mistral::StreamResponse,
581 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
582 let Some(choice) = event.choices.first() else {
583 return vec![Err(LanguageModelCompletionError::from(anyhow!(
584 "Response contained no choices"
585 )))];
586 };
587
588 let mut events = Vec::new();
589 if let Some(content) = choice.delta.content.as_ref() {
590 match content {
591 mistral::MessageContentDelta::Text(text) => {
592 events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
593 }
594 mistral::MessageContentDelta::Parts(parts) => {
595 for part in parts {
596 match part {
597 mistral::MessagePart::Text { text } => {
598 events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
599 }
600 mistral::MessagePart::Thinking { thinking } => {
601 for tp in thinking.iter().cloned() {
602 match tp {
603 mistral::ThinkingPart::Text { text } => {
604 events.push(Ok(
605 LanguageModelCompletionEvent::Thinking {
606 text,
607 signature: None,
608 },
609 ));
610 }
611 }
612 }
613 }
614 mistral::MessagePart::ImageUrl { .. } => {
615 // We currently don't emit a separate event for images in responses.
616 }
617 }
618 }
619 }
620 }
621 }
622
623 if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
624 for tool_call in tool_calls {
625 let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
626
627 if let Some(tool_id) = tool_call.id.clone()
628 && !tool_id.is_empty()
629 && tool_id != "null"
630 {
631 entry.id = tool_id;
632 }
633
634 if let Some(function) = tool_call.function.as_ref() {
635 if let Some(name) = function.name.clone()
636 && !name.is_empty()
637 {
638 entry.name = name;
639 }
640
641 if let Some(arguments) = function.arguments.clone() {
642 entry.arguments.push_str(&arguments);
643 }
644 }
645
646 if !entry.id.is_empty() && !entry.name.is_empty() {
647 if let Ok(input) = serde_json::from_str::<serde_json::Value>(
648 &fix_streamed_json(&entry.arguments),
649 ) {
650 events.push(Ok(LanguageModelCompletionEvent::ToolUse(
651 LanguageModelToolUse {
652 id: entry.id.clone().into(),
653 name: entry.name.as_str().into(),
654 is_input_complete: false,
655 input,
656 raw_input: entry.arguments.clone(),
657 thought_signature: None,
658 },
659 )));
660 }
661 }
662 }
663 }
664
665 if let Some(usage) = event.usage {
666 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
667 input_tokens: usage.prompt_tokens,
668 output_tokens: usage.completion_tokens,
669 cache_creation_input_tokens: 0,
670 cache_read_input_tokens: 0,
671 })));
672 }
673
674 if let Some(finish_reason) = choice.finish_reason.as_deref() {
675 match finish_reason {
676 "stop" => {
677 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
678 }
679 "tool_calls" => {
680 events.extend(self.process_tool_calls());
681 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
682 }
683 unexpected => {
684 log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
685 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
686 }
687 }
688 }
689
690 events
691 }
692
693 fn process_tool_calls(
694 &mut self,
695 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
696 let mut results = Vec::new();
697
698 for (_, tool_call) in self.tool_calls_by_index.drain() {
699 if tool_call.id.is_empty() || tool_call.name.is_empty() {
700 results.push(Err(LanguageModelCompletionError::from(anyhow!(
701 "Received incomplete tool call: missing id or name"
702 ))));
703 continue;
704 }
705
706 match parse_tool_arguments(&tool_call.arguments) {
707 Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
708 LanguageModelToolUse {
709 id: tool_call.id.into(),
710 name: tool_call.name.into(),
711 is_input_complete: true,
712 input,
713 raw_input: tool_call.arguments,
714 thought_signature: None,
715 },
716 ))),
717 Err(error) => {
718 results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
719 id: tool_call.id.into(),
720 tool_name: tool_call.name.into(),
721 raw_input: tool_call.arguments.into(),
722 json_parse_error: error.to_string(),
723 }))
724 }
725 }
726 }
727
728 results
729 }
730}
731
732#[derive(Default)]
733struct RawToolCall {
734 id: String,
735 name: String,
736 arguments: String,
737}
738
739struct ConfigurationView {
740 api_key_editor: Entity<InputField>,
741 state: Entity<State>,
742 load_credentials_task: Option<Task<()>>,
743}
744
745impl ConfigurationView {
746 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
747 let api_key_editor =
748 cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
749
750 cx.observe(&state, |_, _, cx| {
751 cx.notify();
752 })
753 .detach();
754
755 let load_credentials_task = Some(cx.spawn_in(window, {
756 let state = state.clone();
757 async move |this, cx| {
758 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
759 // We don't log an error, because "not signed in" is also an error.
760 let _ = task.await;
761 }
762
763 this.update(cx, |this, cx| {
764 this.load_credentials_task = None;
765 cx.notify();
766 })
767 .log_err();
768 }
769 }));
770
771 Self {
772 api_key_editor,
773 state,
774 load_credentials_task,
775 }
776 }
777
778 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
779 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
780 if api_key.is_empty() {
781 return;
782 }
783
784 // url changes can cause the editor to be displayed again
785 self.api_key_editor
786 .update(cx, |editor, cx| editor.set_text("", window, cx));
787
788 let state = self.state.clone();
789 cx.spawn_in(window, async move |_, cx| {
790 state
791 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
792 .await
793 })
794 .detach_and_log_err(cx);
795 }
796
797 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
798 self.api_key_editor
799 .update(cx, |editor, cx| editor.set_text("", window, cx));
800
801 let state = self.state.clone();
802 cx.spawn_in(window, async move |_, cx| {
803 state
804 .update(cx, |state, cx| state.set_api_key(None, cx))
805 .await
806 })
807 .detach_and_log_err(cx);
808 }
809
810 fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
811 !self.state.read(cx).is_authenticated()
812 }
813}
814
815impl Render for ConfigurationView {
816 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
817 let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
818 let configured_card_label = if env_var_set {
819 format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
820 } else {
821 let api_url = MistralLanguageModelProvider::api_url(cx);
822 if api_url == MISTRAL_API_URL {
823 "API key configured".to_string()
824 } else {
825 format!("API key configured for {}", api_url)
826 }
827 };
828
829 if self.load_credentials_task.is_some() {
830 div().child(Label::new("Loading credentials...")).into_any()
831 } else if self.should_render_api_key_editor(cx) {
832 v_flex()
833 .size_full()
834 .on_action(cx.listener(Self::save_api_key))
835 .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
836 .child(
837 List::new()
838 .child(
839 ListBulletItem::new("")
840 .child(Label::new("Create one by visiting"))
841 .child(ButtonLink::new("Mistral's console", "https://console.mistral.ai/api-keys"))
842 )
843 .child(
844 ListBulletItem::new("Ensure your Mistral account has credits")
845 )
846 .child(
847 ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
848 ),
849 )
850 .child(self.api_key_editor.clone())
851 .child(
852 Label::new(
853 format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
854 )
855 .size(LabelSize::Small).color(Color::Muted),
856 )
857 .into_any()
858 } else {
859 v_flex()
860 .size_full()
861 .gap_1()
862 .child(
863 ConfiguredApiCard::new(configured_card_label)
864 .disabled(env_var_set)
865 .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
866 .when(env_var_set, |this| {
867 this.tooltip_label(format!(
868 "To reset your API key, \
869 unset the {API_KEY_ENV_VAR_NAME} environment variable."
870 ))
871 }),
872 )
873 .into_any()
874 }
875 }
876}
877
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// Builds a one-choice streaming chunk.
    ///
    /// A chunk that carries a `finish_reason` has no tool-call delta.
    /// Otherwise the chunk carries a single tool-call fragment at index 0,
    /// built from `id`, `name` and `arguments`.
    fn tool_call_chunk(
        id: Option<&str>,
        name: Option<&str>,
        arguments: Option<&str>,
        finish_reason: Option<&str>,
    ) -> mistral::StreamResponse {
        let tool_calls = match finish_reason {
            Some(_) => None,
            None => Some(vec![mistral::ToolCallChunk {
                index: 0,
                id: id.map(Into::into),
                function: Some(mistral::FunctionChunk {
                    name: name.map(Into::into),
                    arguments: arguments.map(Into::into),
                }),
            }]),
        };

        mistral::StreamResponse {
            id: "resp".into(),
            object: "chat.completion.chunk".into(),
            created: 0,
            model: "test".into(),
            choices: vec![mistral::StreamChoice {
                index: 0,
                delta: mistral::StreamDelta {
                    role: None,
                    content: None,
                    tool_calls,
                },
                finish_reason: finish_reason.map(Into::into),
            }],
            usage: None,
        }
    }

    /// Wraps `content` in an uncached request message with the given `role`.
    fn request_message(role: Role, content: Vec<MessageContent>) -> LanguageModelRequestMessage {
        LanguageModelRequestMessage {
            role,
            content,
            cache: false,
            reasoning_details: None,
        }
    }

    #[test]
    fn test_streaming_tool_call_ignores_null_id() {
        // Mistral's streaming API sometimes sends `"id": "null"` in continuation chunks.
        let mut mapper = MistralEventMapper::new();

        let fragments = [
            tool_call_chunk(
                Some("real_id_123"),
                Some("read_file"),
                Some("{\"path\":"),
                None,
            ),
            tool_call_chunk(Some("null"), None, Some("\"a.txt\"}"), None),
        ];
        for fragment in fragments {
            mapper.map_event(fragment);
        }
        let events = mapper.map_event(tool_call_chunk(None, None, None, Some("tool_calls")));

        let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] else {
            panic!("Expected first event to be ToolUse, got: {:?}", events[0]);
        };

        assert_eq!(tool_use.id.to_string(), "real_id_123");
        assert_eq!(tool_use.name.as_ref(), "read_file");
        assert_eq!(tool_use.input, serde_json::json!({"path": "a.txt"}));
    }

    #[test]
    fn test_into_mistral_basic_conversion() {
        let request = LanguageModelRequest {
            messages: vec![
                request_message(Role::System, vec![MessageContent::Text("System prompt".into())]),
                request_message(Role::User, vec![MessageContent::Text("Hello".into())]),
                // should skip empty assistant messages
                request_message(Role::Assistant, vec![MessageContent::Text("".into())]),
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: Some("abcdef".into()),
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let (converted, affinity) =
            into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(converted.model, "mistral-small-latest");
        assert_eq!(converted.temperature, Some(0.5));
        assert_eq!(converted.messages.len(), 2);
        assert!(converted.stream);
        assert_eq!(affinity, Some("abcdef".into()));
    }

    #[test]
    fn test_into_mistral_with_image() {
        let request = LanguageModelRequest {
            messages: vec![request_message(
                Role::User,
                vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: None,
                    }),
                ],
            )],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let (converted, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(converted.messages.len(), 1);
        assert!(matches!(
            &converted.messages[0],
            mistral::RequestMessage::User {
                content: mistral::MessageContent::Multipart { .. }
            }
        ));

        if let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content: parts },
        } = &converted.messages[0]
        {
            assert_eq!(parts.len(), 2);
            assert!(matches!(
                &parts[0],
                mistral::MessagePart::Text { text } if text == "What's in this image?"
            ));
            assert!(matches!(
                &parts[1],
                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
            ));
        }
    }
}