use anyhow::{Result, anyhow};
use collections::BTreeMap;
use fs::Fs;
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason, TokenUsage,
};
use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
pub use settings::MistralAvailableModel as AvailableModel;
use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};

use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);

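/// User-facing settings for the Mistral provider: an optional custom API URL
/// and any additional models declared in settings.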
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

pub struct MistralLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}

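/// Credential state for the provider: one API key for the main Mistral API and
/// a separate key for Codestral, which is served from its own endpoint.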
pub struct State {
    api_key_state: ApiKeyState,
    codestral_api_key_state: ApiKeyState,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    fn set_codestral_api_key(
        &mut self,
        api_key: Option<String>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.codestral_api_key_state.store(
            CODESTRAL_API_URL.into(),
            api_key,
            |this| &mut this.codestral_api_key_state,
            cx,
        )
    }

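    /// Loads the Mistral API key (from stored credentials or the `MISTRAL_API_KEY`
    /// environment variable) if it has not been loaded yet.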
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            &API_KEY_ENV_VAR,
            |this| &mut this.api_key_state,
            cx,
        )
    }

    fn authenticate_codestral(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Task<Result<(), AuthenticateError>> {
        self.codestral_api_key_state.load_if_needed(
            CODESTRAL_API_URL.into(),
            &CODESTRAL_API_KEY_ENV_VAR,
            |this| &mut this.codestral_api_key_state,
            cx,
        )
    }
}

struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}

impl MistralLanguageModelProvider {
    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
        cx.try_global::<GlobalMistralLanguageModelProvider>()
            .map(|this| &this.0)
    }

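    /// Returns the shared provider instance, creating it and registering it as a
    /// GPUI global on first use. The state observes the settings store so the
    /// stored API key is re-resolved whenever the configured API URL changes.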
    pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
            return this.0.clone();
        }
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    &API_KEY_ENV_VAR,
                    |this| &mut this.api_key_state,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx)),
                codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
            }
        });

        let this = Arc::new(Self { http_client, state });
        cx.set_global(GlobalMistralLanguageModelProvider(this));
        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
    }

    pub fn load_codestral_api_key(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state
            .update(cx, |state, cx| state.authenticate_codestral(cx))
    }

    pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
        self.state.read(cx).codestral_api_key_state.key(url)
    }

    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
        Arc::new(MistralLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    fn settings(cx: &App) -> &MistralSettings {
        &crate::AllLanguageModelSettings::get_global(cx).mistral
    }

    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            mistral::MISTRAL_API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}

impl LanguageModelProviderState for MistralLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for MistralLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiMistral
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(mistral::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(mistral::Model::default_fast()))
    }

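    /// Lists the built-in Mistral models plus any custom models from settings.
    /// A settings entry with the same name as a built-in model overrides it.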
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from mistral::Model::iter()
        for model in mistral::Model::iter() {
            if !matches!(model, mistral::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &Self::settings(cx).available_models {
            models.insert(
                model.name.clone(),
                mistral::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                    supports_tools: model.supports_tools,
                    supports_images: model.supports_images,
                    supports_thinking: model.supports_thinking,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(MistralLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}

pub struct MistralLanguageModel {
    id: LanguageModelId,
    model: mistral::Model,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl MistralLanguageModel {
    fn stream_completion(
        &self,
        request: mistral::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = MistralLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request =
                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for MistralLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn telemetry_id(&self) -> String {
        format!("mistral/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

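    /// Approximates token counts with OpenAI's GPT-4 tokenizer via `tiktoken_rs`;
    /// Mistral's own tokenizer is not used here, so the result is an estimate.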
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        cx.background_spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.string_contents()),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
        let stream = self.stream_completion(request, cx);

        async move {
            let stream = stream.await?;
            let mapper = MistralEventMapper::new();
            Ok(mapper.map_stream(stream).boxed())
        }
        .boxed()
    }
}

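/// Converts a provider-agnostic `LanguageModelRequest` into a `mistral::Request`,
/// flattening each message's content into the role-specific shapes the Mistral
/// API expects (user/assistant/system messages, tool calls, and tool results).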
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> mistral::Request {
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    mistral::Request {
        model: model.id().to_string(),
        messages,
        stream,
        max_tokens: max_output_tokens,
        temperature: request.temperature,
        response_format: None,
        tool_choice: match request.tool_choice {
            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Auto)
            }
            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Any)
            }
            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
            _ => None,
        },
        parallel_tool_calls: if !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| mistral::ToolDefinition::Function {
                function: mistral::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
    }
}

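/// Maps Mistral's streaming response chunks into `LanguageModelCompletionEvent`s,
/// accumulating partial tool-call deltas by index until a `tool_calls` finish
/// reason arrives.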
pub struct MistralEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

impl MistralEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

    pub fn map_event(
        &mut self,
        event: mistral::StreamResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let Some(choice) = event.choices.first() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content.as_ref() {
            match content {
                mistral::MessageContentDelta::Text(text) => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                }
                mistral::MessageContentDelta::Parts(parts) => {
                    for part in parts {
                        match part {
                            mistral::MessagePart::Text { text } => {
                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                            }
                            mistral::MessagePart::Thinking { thinking } => {
                                for tp in thinking.iter().cloned() {
                                    match tp {
                                        mistral::ThinkingPart::Text { text } => {
                                            events.push(Ok(
                                                LanguageModelCompletionEvent::Thinking {
                                                    text,
                                                    signature: None,
                                                },
                                            ));
                                        }
                                    }
                                }
                            }
                            mistral::MessagePart::ImageUrl { .. } => {
                                // We currently don't emit a separate event for images in responses.
                            }
                        }
                    }
                }
            }
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        if let Some(finish_reason) = choice.finish_reason.as_deref() {
            match finish_reason {
                "stop" => {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
                "tool_calls" => {
                    events.extend(self.process_tool_calls());
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                }
                unexpected => {
                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }
        }

        events
    }

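    /// Drains the accumulated tool-call fragments, parsing each one's JSON
    /// arguments and emitting either a `ToolUse` event or a JSON parse error.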
    fn process_tool_calls(
        &mut self,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut results = Vec::new();

        for (_, tool_call) in self.tool_calls_by_index.drain() {
            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                results.push(Err(LanguageModelCompletionError::from(anyhow!(
                    "Received incomplete tool call: missing id or name"
                ))));
                continue;
            }

            match serde_json::Value::from_str(&tool_call.arguments) {
                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
                    LanguageModelToolUse {
                        id: tool_call.id.into(),
                        name: tool_call.name.into(),
                        is_input_complete: true,
                        input,
                        raw_input: tool_call.arguments,
                    },
                ))),
                Err(error) => {
                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                        id: tool_call.id.into(),
                        tool_name: tool_call.name.into(),
                        raw_input: tool_call.arguments.into(),
                        json_parse_error: error.to_string(),
                    }))
                }
            }
        }

        results
    }
}

#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

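/// Settings UI for the Mistral provider: collects the main API key and an
/// optional Codestral key (used for edit predictions), and shows which keys
/// are already configured.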
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    codestral_api_key_editor: Entity<InputField>,
    state: Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor =
            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
        let codestral_api_key_editor =
            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate_codestral(cx))
                    .log_err()
                {
                    let _ = task.await;
                }

                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            codestral_api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn save_codestral_api_key(
        &mut self,
        _: &menu::Confirm,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        let api_key = self
            .codestral_api_key_editor
            .read(cx)
            .text(cx)
            .trim()
            .to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.codestral_api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| {
                    state.set_codestral_api_key(Some(api_key), cx)
                })?
                .await?;
            cx.update(|_window, cx| {
                set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
            })
        })
        .detach_and_log_err(cx);
    }

    fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.codestral_api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
                .await?;
            cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
        })
        .detach_and_log_err(cx);
    }

    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }

    fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
        let key_state = &self.state.read(cx).codestral_api_key_state;
        let should_show_editor = !key_state.has_key();
        let env_var_set = key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            "Codestral API key configured".to_string()
        };

        if should_show_editor {
            v_flex()
                .id("codestral")
                .size_full()
                .mt_2()
                .on_action(cx.listener(Self::save_codestral_api_key))
                .child(Label::new(
                    "To use Codestral as an edit prediction provider, \
                    you need to add a Codestral-specific API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("the Codestral section of Mistral's console"),
                            Some("https://console.mistral.ai/codestral"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter",
                        )),
                )
                .child(self.codestral_api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(
                    cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
                )
                .when(env_var_set, |this| {
                    this.tooltip_label(format!(
                        "To reset your API key, \
                        unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
                    ))
                })
                .into_any_element()
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = MistralLanguageModelProvider::api_url(cx);
            if api_url == MISTRAL_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_api_key_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Mistral, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("Mistral's console"),
                            Some("https://console.mistral.ai/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your Mistral account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(self.render_codestral_api_key_editor(cx))
                .into_any()
        } else {
            v_flex()
                .size_full()
                .gap_1()
                .child(
                    ConfiguredApiCard::new(configured_card_label)
                        .disabled(env_var_set)
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                        .when(env_var_set, |this| {
                            this.tooltip_label(format!(
                                "To reset your API key, \
                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
                            ))
                        }),
                )
                .child(self.render_codestral_api_key_editor(cx))
                .into_any()
        }
    }
}

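/// Persists the chosen edit prediction provider to the user's settings file.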
fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
    let fs = <dyn Fs>::global(cx);
    update_settings_file(fs, cx, move |settings, _| {
        settings
            .project
            .all_languages
            .features
            .get_or_insert_default()
            .edit_prediction_provider = Some(provider);
    });
}

#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    #[test]
    fn test_into_mistral_basic_conversion() {
        let request = LanguageModelRequest {
            messages: vec![
                LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text("System prompt".into())],
                    cache: false,
                },
                LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![MessageContent::Text("Hello".into())],
                    cache: false,
                },
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
    }

    #[test]
    fn test_into_mistral_with_image() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: Default::default(),
                    }),
                ],
                cache: false,
            }],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        assert!(matches!(
            &mistral_request.messages[0],
            mistral::RequestMessage::User {
                content: mistral::MessageContent::Multipart { .. }
            }
        ));

        if let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &mistral_request.messages[0]
        {
            assert_eq!(content.len(), 2);
            assert!(matches!(
                &content[0],
                mistral::MessagePart::Text { text } if text == "What's in this image?"
            ));
            assert!(matches!(
                &content[1],
                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
            ));
        }
    }
}