1use anyhow::{Result, anyhow};
2use collections::BTreeMap;
3use credentials_provider::CredentialsProvider;
4
5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
6use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
7use http_client::HttpClient;
8use language_model::{
9 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
10 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
13 LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
14};
15pub use mistral::{MISTRAL_API_URL, StreamResponse};
16pub use settings::MistralAvailableModel as AvailableModel;
17use settings::{Settings, SettingsStore};
18use std::collections::HashMap;
19use std::pin::Pin;
20use std::sync::{Arc, LazyLock};
21use strum::IntoEnumIterator;
22use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
23use ui_input::InputField;
24use util::ResultExt;
25
26use language_model::util::{fix_streamed_json, parse_tool_arguments};
27
/// Stable identifier used to register this provider with the model registry.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Environment variable consulted for an API key when none is stored via credentials.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
33
/// User-configurable settings for the Mistral provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    // Base URL for API requests; an empty string means "use MISTRAL_API_URL".
    pub api_url: String,
    // Extra models declared in settings; these override built-ins with the same name.
    pub available_models: Vec<AvailableModel>,
}
39
/// Provider entry point: owns the HTTP client and the shared authentication state.
pub struct MistralLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    pub state: Entity<State>,
}
44
/// Authentication state shared by the provider, its models, and the config view.
pub struct State {
    // Tracks the API key for the currently configured API URL (env var or stored).
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
49
impl State {
    /// Whether an API key is currently available (from env var or stored credentials).
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores the given API key (or clears it when `None`) for the current API URL,
    /// persisting through the credentials provider.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state.store(
            api_url,
            api_key,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }

    /// Loads the API key for the current API URL unless it is already loaded.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }
}
78
/// Newtype registered as a gpui global so the provider is a process-wide singleton.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
82
83impl MistralLanguageModelProvider {
84 pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
85 cx.try_global::<GlobalMistralLanguageModelProvider>()
86 .map(|this| &this.0)
87 }
88
89 pub fn global(
90 http_client: Arc<dyn HttpClient>,
91 credentials_provider: Arc<dyn CredentialsProvider>,
92 cx: &mut App,
93 ) -> Arc<Self> {
94 if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
95 return this.0.clone();
96 }
97 let state = cx.new(|cx| {
98 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
99 let credentials_provider = this.credentials_provider.clone();
100 let api_url = Self::api_url(cx);
101 this.api_key_state.handle_url_change(
102 api_url,
103 |this| &mut this.api_key_state,
104 credentials_provider,
105 cx,
106 );
107 cx.notify();
108 })
109 .detach();
110 State {
111 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
112 credentials_provider,
113 }
114 });
115
116 let this = Arc::new(Self { http_client, state });
117 cx.set_global(GlobalMistralLanguageModelProvider(this));
118 cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
119 }
120
121 fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
122 Arc::new(MistralLanguageModel {
123 id: LanguageModelId::from(model.id().to_string()),
124 model,
125 state: self.state.clone(),
126 http_client: self.http_client.clone(),
127 request_limiter: RateLimiter::new(4),
128 })
129 }
130
131 fn settings(cx: &App) -> &MistralSettings {
132 &crate::AllLanguageModelSettings::get_global(cx).mistral
133 }
134
135 pub fn api_url(cx: &App) -> SharedString {
136 let api_url = &Self::settings(cx).api_url;
137 if api_url.is_empty() {
138 mistral::MISTRAL_API_URL.into()
139 } else {
140 SharedString::new(api_url.as_str())
141 }
142 }
143}
144
impl LanguageModelProviderState for MistralLanguageModelProvider {
    type ObservableEntity = State;

    // Exposes the auth state entity so the UI can re-render on auth changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
152
153impl LanguageModelProvider for MistralLanguageModelProvider {
154 fn id(&self) -> LanguageModelProviderId {
155 PROVIDER_ID
156 }
157
158 fn name(&self) -> LanguageModelProviderName {
159 PROVIDER_NAME
160 }
161
162 fn icon(&self) -> IconOrSvg {
163 IconOrSvg::Icon(IconName::AiMistral)
164 }
165
166 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
167 Some(self.create_language_model(mistral::Model::default()))
168 }
169
170 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
171 Some(self.create_language_model(mistral::Model::default_fast()))
172 }
173
174 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
175 let mut models = BTreeMap::default();
176
177 // Add base models from mistral::Model::iter()
178 for model in mistral::Model::iter() {
179 if !matches!(model, mistral::Model::Custom { .. }) {
180 models.insert(model.id().to_string(), model);
181 }
182 }
183
184 // Override with available models from settings
185 for model in &Self::settings(cx).available_models {
186 models.insert(
187 model.name.clone(),
188 mistral::Model::Custom {
189 name: model.name.clone(),
190 display_name: model.display_name.clone(),
191 max_tokens: model.max_tokens,
192 max_output_tokens: model.max_output_tokens,
193 max_completion_tokens: model.max_completion_tokens,
194 supports_tools: model.supports_tools,
195 supports_images: model.supports_images,
196 supports_thinking: model.supports_thinking,
197 },
198 );
199 }
200
201 models
202 .into_values()
203 .map(|model| {
204 Arc::new(MistralLanguageModel {
205 id: LanguageModelId::from(model.id().to_string()),
206 model,
207 state: self.state.clone(),
208 http_client: self.http_client.clone(),
209 request_limiter: RateLimiter::new(4),
210 }) as Arc<dyn LanguageModel>
211 })
212 .collect()
213 }
214
215 fn is_authenticated(&self, cx: &App) -> bool {
216 self.state.read(cx).is_authenticated()
217 }
218
219 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
220 self.state.update(cx, |state, cx| state.authenticate(cx))
221 }
222
223 fn configuration_view(
224 &self,
225 _target_agent: language_model::ConfigurationViewTargetAgent,
226 window: &mut Window,
227 cx: &mut App,
228 ) -> AnyView {
229 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
230 .into()
231 }
232
233 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
234 self.state
235 .update(cx, |state, cx| state.set_api_key(None, cx))
236 }
237}
238
/// A single Mistral model exposed through the [`LanguageModel`] interface.
pub struct MistralLanguageModel {
    id: LanguageModelId,
    model: mistral::Model,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps the number of concurrent completion requests for this model handle.
    request_limiter: RateLimiter,
}
246
impl MistralLanguageModel {
    /// Starts a streaming completion request, resolving the API key and URL from the
    /// shared state and throttling through the provider's rate limiter.
    ///
    /// Fails with [`LanguageModelCompletionError::NoApiKey`] when no key is set.
    fn stream_completion(
        &self,
        request: mistral::Request,
        affinity: Option<String>,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
    > {
        let http_client = self.http_client.clone();

        // Snapshot the key and URL now so the async block below doesn't need app state.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = MistralLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request = mistral::stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                affinity,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
284
impl LanguageModel for MistralLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_streaming_tools(&self) -> bool {
        true
    }

    // All tool-choice modes are supported iff the model supports tools at all.
    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn telemetry_id(&self) -> String {
        format!("mistral/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Estimates the token count of a request on a background thread.
    ///
    /// NOTE(review): this uses the OpenAI "gpt-4" tokenizer as an approximation;
    /// Mistral's actual tokenizer differs — confirm the estimate is close enough
    /// for context-window accounting.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        cx.background_spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.string_contents()),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    /// Converts the generic request into Mistral's wire format, streams the
    /// completion, and maps raw chunks into `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let (request, affinity) =
            into_mistral(request, self.model.clone(), self.max_output_tokens());
        let stream = self.stream_completion(request, affinity, cx);

        async move {
            let stream = stream.await?;
            let mapper = MistralEventMapper::new();
            Ok(mapper.map_stream(stream).boxed())
        }
        .boxed()
    }
}
379
/// Converts a [`LanguageModelRequest`] into a Mistral API request.
///
/// Returns the request plus the request's `thread_id`, which callers pass to
/// `mistral::stream_completion` as an affinity hint.
///
/// Message translation notes:
/// - Tool results found inside a User message become standalone `Tool` messages,
///   emitted before any remaining user content from that same message.
/// - Image and thinking parts are dropped when the model doesn't support them.
/// - Empty assistant text is skipped (the API rejects messages with neither
///   content nor tool calls).
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> (mistral::Request, Option<String>) {
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                // Accumulate multimodal parts; tool results are split out separately.
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                // Only emit a User message if any content survived the filtering above.
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) if text.is_empty() => {
                            // Mistral API returns a 400 if there's neither content nor tool_calls
                        }
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            // Attach to the preceding Assistant message when possible so
                            // multiple tool calls share one message.
                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    (
        mistral::Request {
            model: model.id().to_string(),
            messages,
            stream,
            stream_options: if stream {
                Some(mistral::StreamOptions {
                    stream_tool_calls: Some(true),
                })
            } else {
                None
            },
            max_tokens: max_output_tokens,
            temperature: request.temperature,
            response_format: None,
            // Default to Auto whenever tools are present, even without an explicit choice.
            tool_choice: match request.tool_choice {
                Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                    Some(mistral::ToolChoice::Auto)
                }
                Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                    Some(mistral::ToolChoice::Any)
                }
                Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
                _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
                _ => None,
            },
            // Parallel tool calls are explicitly disabled when tools are provided.
            parallel_tool_calls: if !request.tools.is_empty() {
                Some(false)
            } else {
                None
            },
            tools: request
                .tools
                .into_iter()
                .map(|tool| mistral::ToolDefinition::Function {
                    function: mistral::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        },
        request.thread_id,
    )
}
579
/// Translates Mistral stream chunks into [`LanguageModelCompletionEvent`]s,
/// accumulating partial tool calls across chunks keyed by their stream index.
pub struct MistralEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
583
impl MistralEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    /// Consumes the mapper and adapts a raw chunk stream into completion events.
    /// Each chunk may expand into zero or more events.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

    /// Maps one stream chunk to completion events: text/thinking deltas, partial
    /// tool-use updates, usage updates, and a stop event on `finish_reason`.
    pub fn map_event(
        &mut self,
        event: mistral::StreamResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Only the first choice is considered; n > 1 is never requested.
        let Some(choice) = event.choices.first() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content.as_ref() {
            match content {
                mistral::MessageContentDelta::Text(text) => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                }
                mistral::MessageContentDelta::Parts(parts) => {
                    for part in parts {
                        match part {
                            mistral::MessagePart::Text { text } => {
                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                            }
                            mistral::MessagePart::Thinking { thinking } => {
                                for tp in thinking.iter().cloned() {
                                    match tp {
                                        mistral::ThinkingPart::Text { text } => {
                                            events.push(Ok(
                                                LanguageModelCompletionEvent::Thinking {
                                                    text,
                                                    signature: None,
                                                },
                                            ));
                                        }
                                    }
                                }
                            }
                            mistral::MessagePart::ImageUrl { .. } => {
                                // We currently don't emit a separate event for images in responses.
                            }
                        }
                    }
                }
            }
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                // Continuation chunks can carry the literal string "null" (or an
                // empty id); those must not overwrite the real id.
                if let Some(tool_id) = tool_call.id.clone()
                    && !tool_id.is_empty()
                    && tool_id != "null"
                {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone()
                        && !name.is_empty()
                    {
                        entry.name = name;
                    }

                    // Argument JSON arrives in fragments; concatenate as-is.
                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }

                // Emit a provisional (is_input_complete: false) tool-use event whenever
                // the accumulated arguments parse after JSON repair; the final event is
                // emitted by process_tool_calls on the "tool_calls" finish reason.
                if !entry.id.is_empty() && !entry.name.is_empty() {
                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
                        &fix_streamed_json(&entry.arguments),
                    ) {
                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: entry.id.clone().into(),
                                name: entry.name.as_str().into(),
                                is_input_complete: false,
                                input,
                                raw_input: entry.arguments.clone(),
                                thought_signature: None,
                            },
                        )));
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            // Mistral reports no cache token counts, so both cache fields stay 0.
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        if let Some(finish_reason) = choice.finish_reason.as_deref() {
            match finish_reason {
                "stop" => {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
                "tool_calls" => {
                    events.extend(self.process_tool_calls());
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                }
                unexpected => {
                    // Unknown reasons are logged but treated as a normal end of turn.
                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }
        }

        events
    }

    /// Drains all accumulated tool calls, emitting a final (is_input_complete: true)
    /// event per call, or an error/parse-error event for incomplete or invalid ones.
    fn process_tool_calls(
        &mut self,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut results = Vec::new();

        for (_, tool_call) in self.tool_calls_by_index.drain() {
            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                results.push(Err(LanguageModelCompletionError::from(anyhow!(
                    "Received incomplete tool call: missing id or name"
                ))));
                continue;
            }

            match parse_tool_arguments(&tool_call.arguments) {
                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
                    LanguageModelToolUse {
                        id: tool_call.id.into(),
                        name: tool_call.name.into(),
                        is_input_complete: true,
                        input,
                        raw_input: tool_call.arguments,
                        thought_signature: None,
                    },
                ))),
                Err(error) => {
                    // Surfaced as an event (not an Err) so the turn can continue.
                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                        id: tool_call.id.into(),
                        tool_name: tool_call.name.into(),
                        raw_input: tool_call.arguments.into(),
                        json_parse_error: error.to_string(),
                    }))
                }
            }
        }

        results
    }
}
757
/// A tool call being assembled from streamed fragments; `arguments` holds the
/// concatenated (possibly still incomplete) JSON text.
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}
764
/// UI for entering, saving, and resetting the Mistral API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // `Some` while the initial credential load is in flight; cleared when it finishes.
    load_credentials_task: Option<Task<()>>,
}
770
771impl ConfigurationView {
772 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
773 let api_key_editor =
774 cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
775
776 cx.observe(&state, |_, _, cx| {
777 cx.notify();
778 })
779 .detach();
780
781 let load_credentials_task = Some(cx.spawn_in(window, {
782 let state = state.clone();
783 async move |this, cx| {
784 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
785 // We don't log an error, because "not signed in" is also an error.
786 let _ = task.await;
787 }
788
789 this.update(cx, |this, cx| {
790 this.load_credentials_task = None;
791 cx.notify();
792 })
793 .log_err();
794 }
795 }));
796
797 Self {
798 api_key_editor,
799 state,
800 load_credentials_task,
801 }
802 }
803
804 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
805 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
806 if api_key.is_empty() {
807 return;
808 }
809
810 // url changes can cause the editor to be displayed again
811 self.api_key_editor
812 .update(cx, |editor, cx| editor.set_text("", window, cx));
813
814 let state = self.state.clone();
815 cx.spawn_in(window, async move |_, cx| {
816 state
817 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
818 .await
819 })
820 .detach_and_log_err(cx);
821 }
822
823 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
824 self.api_key_editor
825 .update(cx, |editor, cx| editor.set_text("", window, cx));
826
827 let state = self.state.clone();
828 cx.spawn_in(window, async move |_, cx| {
829 state
830 .update(cx, |state, cx| state.set_api_key(None, cx))
831 .await
832 })
833 .detach_and_log_err(cx);
834 }
835
836 fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
837 !self.state.read(cx).is_authenticated()
838 }
839}
840
impl Render for ConfigurationView {
    // Renders one of three states: loading, key-entry instructions, or the
    // "configured" card (with reset disabled when the key comes from the env var).
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            // Mention the URL only when a non-default endpoint is configured.
            let api_url = MistralLanguageModelProvider::api_url(cx);
            if api_url == MISTRAL_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_api_key_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create one by visiting"))
                                .child(ButtonLink::new("Mistral's console", "https://console.mistral.ai/api-keys"))
                        )
                        .child(
                            ListBulletItem::new("Ensure your Mistral account has credits")
                        )
                        .child(
                            ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
                        ),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            v_flex()
                .size_full()
                .gap_1()
                .child(
                    ConfiguredApiCard::new(configured_card_label)
                        .disabled(env_var_set)
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                        .when(env_var_set, |this| {
                            this.tooltip_label(format!(
                                "To reset your API key, \
                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
                            ))
                        }),
                )
                .into_any()
        }
    }
}
903
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// Builds a single-choice stream chunk carrying either a tool-call fragment
    /// or (when `finish_reason` is set) a bare finish chunk with no tool calls.
    fn tool_call_chunk(
        id: Option<&str>,
        name: Option<&str>,
        arguments: Option<&str>,
        finish_reason: Option<&str>,
    ) -> mistral::StreamResponse {
        mistral::StreamResponse {
            id: "resp".into(),
            object: "chat.completion.chunk".into(),
            created: 0,
            model: "test".into(),
            choices: vec![mistral::StreamChoice {
                index: 0,
                delta: mistral::StreamDelta {
                    role: None,
                    content: None,
                    tool_calls: if finish_reason.is_some() {
                        None
                    } else {
                        Some(vec![mistral::ToolCallChunk {
                            index: 0,
                            id: id.map(Into::into),
                            function: Some(mistral::FunctionChunk {
                                name: name.map(Into::into),
                                arguments: arguments.map(Into::into),
                            }),
                        }])
                    },
                },
                finish_reason: finish_reason.map(Into::into),
            }],
            usage: None,
        }
    }

    #[test]
    fn test_streaming_tool_call_ignores_null_id() {
        // Mistral's streaming API sometimes sends `"id": "null"` in continuation chunks.
        let mut mapper = MistralEventMapper::new();

        mapper.map_event(tool_call_chunk(
            Some("real_id_123"),
            Some("read_file"),
            Some("{\"path\":"),
            None,
        ));
        mapper.map_event(tool_call_chunk(
            Some("null"),
            None,
            Some("\"a.txt\"}"),
            None,
        ));
        let events = mapper.map_event(tool_call_chunk(None, None, None, Some("tool_calls")));

        // The finish chunk must flush the accumulated call with the original id intact.
        let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] else {
            panic!("Expected first event to be ToolUse, got: {:?}", events[0]);
        };

        assert_eq!(tool_use.id.to_string(), "real_id_123");
        assert_eq!(tool_use.name.as_ref(), "read_file");
        assert_eq!(tool_use.input, serde_json::json!({"path": "a.txt"}));
    }

    #[test]
    fn test_into_mistral_basic_conversion() {
        let request = LanguageModelRequest {
            messages: vec![
                LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text("System prompt".into())],
                    cache: false,
                    reasoning_details: None,
                },
                LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![MessageContent::Text("Hello".into())],
                    cache: false,
                    reasoning_details: None,
                },
                // should skip empty assistant messages
                LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::Text("".into())],
                    cache: false,
                    reasoning_details: None,
                },
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: Some("abcdef".into()),
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: Default::default(),
        };

        let (mistral_request, affinity) =
            into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        // Only the system and user messages survive; the empty assistant one is dropped.
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
        // The thread id is returned as the affinity key.
        assert_eq!(affinity, Some("abcdef".into()));
    }

    #[test]
    fn test_into_mistral_with_image() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: None,
                    }),
                ],
                cache: false,
                reasoning_details: None,
            }],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        // Pixtral supports images, so the image part must be preserved as a data URL.
        let (mistral_request, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        assert!(matches!(
            &mistral_request.messages[0],
            mistral::RequestMessage::User {
                content: mistral::MessageContent::Multipart { .. }
            }
        ));

        if let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &mistral_request.messages[0]
        {
            assert_eq!(content.len(), 2);
            assert!(matches!(
                &content[0],
                mistral::MessagePart::Text { text } if text == "What's in this image?"
            ));
            assert!(matches!(
                &content[1],
                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
            ));
        }
    }
}