use anyhow::{Result, anyhow};
use collections::BTreeMap;
use fs::Fs;
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason, TokenUsage,
};
use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
pub use settings::MistralAvailableModel as AvailableModel;
use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};

use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);

#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

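/// Language model provider backed by the Mistral API. Holds the shared HTTP
/// client and the entity tracking API key state.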
pub struct MistralLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}

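/// Authentication state for the provider: one API key for the main Mistral
/// API (chat models) and a separate one for Codestral (edit predictions).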
pub struct State {
    api_key_state: ApiKeyState,
    codestral_api_key_state: ApiKeyState,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    fn set_codestral_api_key(
        &mut self,
        api_key: Option<String>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.codestral_api_key_state.store(
            CODESTRAL_API_URL.into(),
            api_key,
            |this| &mut this.codestral_api_key_state,
            cx,
        )
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            &API_KEY_ENV_VAR,
            |this| &mut this.api_key_state,
            cx,
        )
    }

    fn authenticate_codestral(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Task<Result<(), AuthenticateError>> {
        self.codestral_api_key_state.load_if_needed(
            CODESTRAL_API_URL.into(),
            &CODESTRAL_API_KEY_ENV_VAR,
            |this| &mut this.codestral_api_key_state,
            cx,
        )
    }
}

struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}

impl MistralLanguageModelProvider {
    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
        cx.try_global::<GlobalMistralLanguageModelProvider>()
            .map(|this| &this.0)
    }

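    /// Returns the process-wide provider, creating it (and subscribing its
    /// state to settings changes) on first use.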
    pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
            return this.0.clone();
        }
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    &API_KEY_ENV_VAR,
                    |this| &mut this.api_key_state,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx)),
                codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
            }
        });

        let this = Arc::new(Self { http_client, state });
        cx.set_global(GlobalMistralLanguageModelProvider(this));
        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
    }

    pub fn load_codestral_api_key(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state
            .update(cx, |state, cx| state.authenticate_codestral(cx))
    }

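    /// Returns the Codestral API key for `url`, if one has been loaded.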
    pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
        self.state.read(cx).codestral_api_key_state.key(url)
    }

    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
        Arc::new(MistralLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    fn settings(cx: &App) -> &MistralSettings {
        &crate::AllLanguageModelSettings::get_global(cx).mistral
    }

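    /// The API URL from settings, falling back to the default Mistral endpoint
    /// when the setting is empty.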
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            mistral::MISTRAL_API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}

impl LanguageModelProviderState for MistralLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for MistralLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiMistral
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(mistral::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(mistral::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from mistral::Model::iter()
        for model in mistral::Model::iter() {
            if !matches!(model, mistral::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &Self::settings(cx).available_models {
            models.insert(
                model.name.clone(),
                mistral::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                    supports_tools: model.supports_tools,
                    supports_images: model.supports_images,
                    supports_thinking: model.supports_thinking,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(MistralLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}

pub struct MistralLanguageModel {
    id: LanguageModelId,
    model: mistral::Model,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl MistralLanguageModel {
    fn stream_completion(
        &self,
        request: mistral::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = MistralLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request =
                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for MistralLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn telemetry_id(&self) -> String {
        format!("mistral/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        cx.background_spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.string_contents()),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

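            // Token counts are approximated with the OpenAI tokenizer via
            // tiktoken-rs; Mistral's own tokenizer isn't used here.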
            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
        let stream = self.stream_completion(request, cx);

        async move {
            let stream = stream.await?;
            let mapper = MistralEventMapper::new();
            Ok(mapper.map_stream(stream).boxed())
        }
        .boxed()
    }
}

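/// Converts a provider-agnostic [`LanguageModelRequest`] into a Mistral API
/// request, flattening each message's content parts into the role-specific
/// message shapes Mistral expects.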
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> mistral::Request {
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    mistral::Request {
        model: model.id().to_string(),
        messages,
        stream,
        max_tokens: max_output_tokens,
        temperature: request.temperature,
        response_format: None,
        tool_choice: match request.tool_choice {
            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Auto)
            }
            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Any)
            }
            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
            _ => None,
        },
        parallel_tool_calls: if !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| mistral::ToolDefinition::Function {
                function: mistral::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
    }
}

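/// Maps Mistral's streaming chunks into [`LanguageModelCompletionEvent`]s,
/// accumulating partial tool-call deltas until a `tool_calls` finish reason
/// arrives.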
pub struct MistralEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl MistralEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

    pub fn map_event(
        &mut self,
        event: mistral::StreamResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let Some(choice) = event.choices.first() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content.as_ref() {
            match content {
                mistral::MessageContentDelta::Text(text) => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                }
                mistral::MessageContentDelta::Parts(parts) => {
                    for part in parts {
                        match part {
                            mistral::MessagePart::Text { text } => {
                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                            }
                            mistral::MessagePart::Thinking { thinking } => {
                                for tp in thinking.iter().cloned() {
                                    match tp {
                                        mistral::ThinkingPart::Text { text } => {
                                            events.push(Ok(
                                                LanguageModelCompletionEvent::Thinking {
                                                    text,
                                                    signature: None,
                                                },
                                            ));
                                        }
                                    }
                                }
                            }
                            mistral::MessagePart::ImageUrl { .. } => {
                                // We currently don't emit a separate event for images in responses.
                            }
                        }
                    }
                }
            }
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        if let Some(finish_reason) = choice.finish_reason.as_deref() {
            match finish_reason {
                "stop" => {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
                "tool_calls" => {
                    events.extend(self.process_tool_calls());
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                }
                unexpected => {
                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }
        }

        events
    }

    fn process_tool_calls(
        &mut self,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut results = Vec::new();

        for (_, tool_call) in self.tool_calls_by_index.drain() {
            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                results.push(Err(LanguageModelCompletionError::from(anyhow!(
                    "Received incomplete tool call: missing id or name"
                ))));
                continue;
            }

            match serde_json::Value::from_str(&tool_call.arguments) {
                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
                    LanguageModelToolUse {
                        id: tool_call.id.into(),
                        name: tool_call.name.into(),
                        is_input_complete: true,
                        input,
                        raw_input: tool_call.arguments,
                        thought_signature: None,
                    },
                ))),
                Err(error) => {
                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                        id: tool_call.id.into(),
                        tool_name: tool_call.name.into(),
                        raw_input: tool_call.arguments.into(),
                        json_parse_error: error.to_string(),
                    }))
                }
            }
        }

        results
    }
}

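/// A tool call assembled from streamed deltas; `arguments` holds the JSON
/// string concatenated across chunks.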
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    codestral_api_key_editor: Entity<InputField>,
    state: Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor =
            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
        let codestral_api_key_editor =
            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate_codestral(cx))
                    .log_err()
                {
                    let _ = task.await;
                }

                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            codestral_api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn save_codestral_api_key(
        &mut self,
        _: &menu::Confirm,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        let api_key = self
            .codestral_api_key_editor
            .read(cx)
            .text(cx)
            .trim()
            .to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.codestral_api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| {
                    state.set_codestral_api_key(Some(api_key), cx)
                })?
                .await?;
            cx.update(|_window, cx| {
                set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
            })
        })
        .detach_and_log_err(cx);
    }

    fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.codestral_api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
                .await?;
            cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
        })
        .detach_and_log_err(cx);
    }

    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }

    fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
        let key_state = &self.state.read(cx).codestral_api_key_state;
        let should_show_editor = !key_state.has_key();
        let env_var_set = key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            "Codestral API key configured".to_string()
        };

        if should_show_editor {
            v_flex()
                .id("codestral")
                .size_full()
                .mt_2()
                .on_action(cx.listener(Self::save_codestral_api_key))
                .child(Label::new(
                    "To use Codestral as an edit prediction provider, \
                    you need to add a Codestral-specific API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("the Codestral section of Mistral's console"),
                            Some("https://console.mistral.ai/codestral"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter",
                        )),
                )
                .child(self.codestral_api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .when(env_var_set, |this| {
                    this.tooltip_label(format!(
                        "To reset your API key, \
                        unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
                    ))
                })
                .on_click(
                    cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
                )
                .into_any_element()
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = MistralLanguageModelProvider::api_url(cx);
            if api_url == MISTRAL_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_api_key_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Mistral, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("Mistral's console"),
                            Some("https://console.mistral.ai/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your Mistral account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(self.render_codestral_api_key_editor(cx))
                .into_any()
        } else {
            v_flex()
                .size_full()
                .gap_1()
                .child(
                    ConfiguredApiCard::new(configured_card_label)
                        .disabled(env_var_set)
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                        .when(env_var_set, |this| {
                            this.tooltip_label(format!(
                                "To reset your API key, \
                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
                            ))
                        }),
                )
                .child(self.render_codestral_api_key_editor(cx))
                .into_any()
        }
    }
}

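/// Persists the chosen edit prediction provider to the user's settings file.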
fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
    let fs = <dyn Fs>::global(cx);
    update_settings_file(fs, cx, move |settings, _| {
        settings
            .project
            .all_languages
            .features
            .get_or_insert_default()
            .edit_prediction_provider = Some(provider);
    });
}

#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    #[test]
    fn test_into_mistral_basic_conversion() {
        let request = LanguageModelRequest {
            messages: vec![
                LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text("System prompt".into())],
                    cache: false,
                    reasoning_details: None,
                },
                LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![MessageContent::Text("Hello".into())],
                    cache: false,
                    reasoning_details: None,
                },
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
    }

    #[test]
    fn test_into_mistral_with_image() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: Default::default(),
                    }),
                ],
                cache: false,
                reasoning_details: None,
            }],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        assert!(matches!(
            &mistral_request.messages[0],
            mistral::RequestMessage::User {
                content: mistral::MessageContent::Multipart { .. }
            }
        ));

        if let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &mistral_request.messages[0]
        {
            assert_eq!(content.len(), 2);
            assert!(matches!(
                &content[0],
                mistral::MessagePart::Text { text } if text == "What's in this image?"
            ));
            assert!(matches!(
                &content[1],
                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
            ));
        }
    }
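
    // Sketch of an additional check (not part of the original file): with no
    // tools in the request, `into_mistral` should leave `tools` empty and omit
    // `tool_choice`/`parallel_tool_calls`. Field names mirror the tests above.
    #[test]
    fn test_into_mistral_without_tools_omits_tool_fields() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("Hello".into())],
                cache: false,
                reasoning_details: None,
            }],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert!(mistral_request.tools.is_empty());
        assert!(mistral_request.tool_choice.is_none());
        assert!(mistral_request.parallel_tool_calls.is_none());
        assert_eq!(mistral_request.max_tokens, None);
    }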
}