1use anyhow::{Result, anyhow};
2use collections::BTreeMap;
3use fs::Fs;
4use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
5use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
6use http_client::HttpClient;
7use language_model::{
8 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
9 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
10 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
11 LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
12 RateLimiter, Role, StopReason, TokenUsage,
13};
14use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
15pub use settings::MistralAvailableModel as AvailableModel;
16use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
17use std::collections::HashMap;
18use std::pin::Pin;
19use std::str::FromStr;
20use std::sync::{Arc, LazyLock};
21use strum::IntoEnumIterator;
22use ui::{Icon, IconName, List, Tooltip, prelude::*};
23use ui_input::InputField;
24use util::{ResultExt, truncate_and_trailoff};
25use zed_env_vars::{EnvVar, env_var};
26
27use crate::{api_key::ApiKeyState, ui::InstructionListItem};
28
/// Stable identifier under which this provider and its models are registered.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Display name of the provider; also reported in `NoApiKey` completion errors.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Environment variable consulted for the Mistral chat API key.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

/// Environment variable consulted for the separate Codestral (edit prediction) API key.
const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
37
/// Mistral-specific portion of the language-model settings.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// Base URL for the chat API; an empty string means `mistral::MISTRAL_API_URL`.
    pub api_url: String,
    /// User-declared models; these are added to (or override by name) the built-in list.
    pub available_models: Vec<AvailableModel>,
}
43
/// Language-model provider backed by the Mistral chat API.
pub struct MistralLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Credential state shared with the models and the configuration view.
    state: Entity<State>,
}
48
/// Credential state for the Mistral provider.
///
/// The chat API key and the Codestral key are tracked independently: the former is
/// associated with the configurable API URL, the latter always with `CODESTRAL_API_URL`.
pub struct State {
    api_key_state: ApiKeyState,
    codestral_api_key_state: ApiKeyState,
}
53
54impl State {
55 fn is_authenticated(&self) -> bool {
56 self.api_key_state.has_key()
57 }
58
59 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
60 let api_url = MistralLanguageModelProvider::api_url(cx);
61 self.api_key_state
62 .store(api_url, api_key, |this| &mut this.api_key_state, cx)
63 }
64
65 fn set_codestral_api_key(
66 &mut self,
67 api_key: Option<String>,
68 cx: &mut Context<Self>,
69 ) -> Task<Result<()>> {
70 self.codestral_api_key_state.store(
71 CODESTRAL_API_URL.into(),
72 api_key,
73 |this| &mut this.codestral_api_key_state,
74 cx,
75 )
76 }
77
78 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
79 let api_url = MistralLanguageModelProvider::api_url(cx);
80 self.api_key_state.load_if_needed(
81 api_url,
82 &API_KEY_ENV_VAR,
83 |this| &mut this.api_key_state,
84 cx,
85 )
86 }
87
88 fn authenticate_codestral(
89 &mut self,
90 cx: &mut Context<Self>,
91 ) -> Task<Result<(), AuthenticateError>> {
92 self.codestral_api_key_state.load_if_needed(
93 CODESTRAL_API_URL.into(),
94 &CODESTRAL_API_KEY_ENV_VAR,
95 |this| &mut this.codestral_api_key_state,
96 cx,
97 )
98 }
99}
100
/// gpui global holding the single shared provider instance created by
/// `MistralLanguageModelProvider::global`.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
104
105impl MistralLanguageModelProvider {
106 pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
107 cx.try_global::<GlobalMistralLanguageModelProvider>()
108 .map(|this| &this.0)
109 }
110
111 pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
112 if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
113 return this.0.clone();
114 }
115 let state = cx.new(|cx| {
116 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
117 let api_url = Self::api_url(cx);
118 this.api_key_state.handle_url_change(
119 api_url,
120 &API_KEY_ENV_VAR,
121 |this| &mut this.api_key_state,
122 cx,
123 );
124 cx.notify();
125 })
126 .detach();
127 State {
128 api_key_state: ApiKeyState::new(Self::api_url(cx)),
129 codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
130 }
131 });
132
133 let this = Arc::new(Self { http_client, state });
134 cx.set_global(GlobalMistralLanguageModelProvider(this));
135 cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
136 }
137
138 pub fn load_codestral_api_key(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
139 self.state
140 .update(cx, |state, cx| state.authenticate_codestral(cx))
141 }
142
143 pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
144 self.state.read(cx).codestral_api_key_state.key(url)
145 }
146
147 fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
148 Arc::new(MistralLanguageModel {
149 id: LanguageModelId::from(model.id().to_string()),
150 model,
151 state: self.state.clone(),
152 http_client: self.http_client.clone(),
153 request_limiter: RateLimiter::new(4),
154 })
155 }
156
157 fn settings(cx: &App) -> &MistralSettings {
158 &crate::AllLanguageModelSettings::get_global(cx).mistral
159 }
160
161 fn api_url(cx: &App) -> SharedString {
162 let api_url = &Self::settings(cx).api_url;
163 if api_url.is_empty() {
164 mistral::MISTRAL_API_URL.into()
165 } else {
166 SharedString::new(api_url.as_str())
167 }
168 }
169}
170
171impl LanguageModelProviderState for MistralLanguageModelProvider {
172 type ObservableEntity = State;
173
174 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
175 Some(self.state.clone())
176 }
177}
178
179impl LanguageModelProvider for MistralLanguageModelProvider {
180 fn id(&self) -> LanguageModelProviderId {
181 PROVIDER_ID
182 }
183
184 fn name(&self) -> LanguageModelProviderName {
185 PROVIDER_NAME
186 }
187
188 fn icon(&self) -> IconName {
189 IconName::AiMistral
190 }
191
192 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
193 Some(self.create_language_model(mistral::Model::default()))
194 }
195
196 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
197 Some(self.create_language_model(mistral::Model::default_fast()))
198 }
199
200 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
201 let mut models = BTreeMap::default();
202
203 // Add base models from mistral::Model::iter()
204 for model in mistral::Model::iter() {
205 if !matches!(model, mistral::Model::Custom { .. }) {
206 models.insert(model.id().to_string(), model);
207 }
208 }
209
210 // Override with available models from settings
211 for model in &Self::settings(cx).available_models {
212 models.insert(
213 model.name.clone(),
214 mistral::Model::Custom {
215 name: model.name.clone(),
216 display_name: model.display_name.clone(),
217 max_tokens: model.max_tokens,
218 max_output_tokens: model.max_output_tokens,
219 max_completion_tokens: model.max_completion_tokens,
220 supports_tools: model.supports_tools,
221 supports_images: model.supports_images,
222 supports_thinking: model.supports_thinking,
223 },
224 );
225 }
226
227 models
228 .into_values()
229 .map(|model| {
230 Arc::new(MistralLanguageModel {
231 id: LanguageModelId::from(model.id().to_string()),
232 model,
233 state: self.state.clone(),
234 http_client: self.http_client.clone(),
235 request_limiter: RateLimiter::new(4),
236 }) as Arc<dyn LanguageModel>
237 })
238 .collect()
239 }
240
241 fn is_authenticated(&self, cx: &App) -> bool {
242 self.state.read(cx).is_authenticated()
243 }
244
245 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
246 self.state.update(cx, |state, cx| state.authenticate(cx))
247 }
248
249 fn configuration_view(
250 &self,
251 _target_agent: language_model::ConfigurationViewTargetAgent,
252 window: &mut Window,
253 cx: &mut App,
254 ) -> AnyView {
255 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
256 .into()
257 }
258
259 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
260 self.state
261 .update(cx, |state, cx| state.set_api_key(None, cx))
262 }
263}
264
/// A single Mistral model implementing [`LanguageModel`].
pub struct MistralLanguageModel {
    /// Identifier derived from the underlying model's `id()`.
    id: LanguageModelId,
    model: mistral::Model,
    /// Provider credential state, read at request time for the API key and URL.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Built with `RateLimiter::new(4)` by the provider (presumably a cap on
    /// concurrent requests — confirm against `RateLimiter`).
    request_limiter: RateLimiter,
}
272
273impl MistralLanguageModel {
274 fn stream_completion(
275 &self,
276 request: mistral::Request,
277 cx: &AsyncApp,
278 ) -> BoxFuture<
279 'static,
280 Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
281 > {
282 let http_client = self.http_client.clone();
283
284 let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
285 let api_url = MistralLanguageModelProvider::api_url(cx);
286 (state.api_key_state.key(&api_url), api_url)
287 }) else {
288 return future::ready(Err(anyhow!("App state dropped"))).boxed();
289 };
290
291 let future = self.request_limiter.stream(async move {
292 let Some(api_key) = api_key else {
293 return Err(LanguageModelCompletionError::NoApiKey {
294 provider: PROVIDER_NAME,
295 });
296 };
297 let request =
298 mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
299 let response = request.await?;
300 Ok(response)
301 });
302
303 async move { Ok(future.await?.boxed()) }.boxed()
304 }
305}
306
307impl LanguageModel for MistralLanguageModel {
308 fn id(&self) -> LanguageModelId {
309 self.id.clone()
310 }
311
312 fn name(&self) -> LanguageModelName {
313 LanguageModelName::from(self.model.display_name().to_string())
314 }
315
316 fn provider_id(&self) -> LanguageModelProviderId {
317 PROVIDER_ID
318 }
319
320 fn provider_name(&self) -> LanguageModelProviderName {
321 PROVIDER_NAME
322 }
323
324 fn supports_tools(&self) -> bool {
325 self.model.supports_tools()
326 }
327
328 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
329 self.model.supports_tools()
330 }
331
332 fn supports_images(&self) -> bool {
333 self.model.supports_images()
334 }
335
336 fn telemetry_id(&self) -> String {
337 format!("mistral/{}", self.model.id())
338 }
339
340 fn max_token_count(&self) -> u64 {
341 self.model.max_token_count()
342 }
343
344 fn max_output_tokens(&self) -> Option<u64> {
345 self.model.max_output_tokens()
346 }
347
348 fn count_tokens(
349 &self,
350 request: LanguageModelRequest,
351 cx: &App,
352 ) -> BoxFuture<'static, Result<u64>> {
353 cx.background_spawn(async move {
354 let messages = request
355 .messages
356 .into_iter()
357 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
358 role: match message.role {
359 Role::User => "user".into(),
360 Role::Assistant => "assistant".into(),
361 Role::System => "system".into(),
362 },
363 content: Some(message.string_contents()),
364 name: None,
365 function_call: None,
366 })
367 .collect::<Vec<_>>();
368
369 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
370 })
371 .boxed()
372 }
373
374 fn stream_completion(
375 &self,
376 request: LanguageModelRequest,
377 cx: &AsyncApp,
378 ) -> BoxFuture<
379 'static,
380 Result<
381 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
382 LanguageModelCompletionError,
383 >,
384 > {
385 let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
386 let stream = self.stream_completion(request, cx);
387
388 async move {
389 let stream = stream.await?;
390 let mapper = MistralEventMapper::new();
391 Ok(mapper.map_stream(stream).boxed())
392 }
393 .boxed()
394 }
395}
396
/// Converts a provider-agnostic [`LanguageModelRequest`] into a streaming
/// `mistral::Request` for `model`.
///
/// Mapping rules, per request message role:
/// - **User**: text parts, plus image and thinking parts when `model` supports them, are
///   collected into one user message, which is skipped if it is still empty plain content.
///   Each tool result becomes its own `Tool` message. Tool results are pushed while
///   iterating, so they precede the user message built from the same request message.
/// - **Assistant**: every text or (supported) thinking part becomes a separate assistant
///   message. A tool use is attached to the last message when that is an assistant
///   message; otherwise a new content-less assistant message is created for it.
/// - **System**: text and (supported) thinking parts become system messages.
///
/// Redacted thinking, and content kinds that do not fit a role, are dropped.
/// `max_output_tokens` is sent as `max_tokens`. When tools are present, parallel tool
/// calls are disabled.
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> mistral::Request {
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            // Images are silently dropped for models without vision support.
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            // Pushed immediately, i.e. ahead of this message's user content.
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                // Skip the user message entirely when nothing was collected into it.
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            // A failure to serialize the input yields empty arguments rather than an error.
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            // Group consecutive tool uses onto the preceding assistant message
                            // (which may come from an earlier request message).
                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    mistral::Request {
        model: model.id().to_string(),
        messages,
        stream,
        max_tokens: max_output_tokens,
        temperature: request.temperature,
        response_format: None,
        // Auto/Any are only forwarded when tools exist; with tools but no explicit
        // choice, Auto is used. An explicit None is always forwarded.
        tool_choice: match request.tool_choice {
            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Auto)
            }
            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Any)
            }
            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
            _ => None,
        },
        parallel_tool_calls: if !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| mistral::ToolDefinition::Function {
                function: mistral::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
    }
}
583
584pub struct MistralEventMapper {
585 tool_calls_by_index: HashMap<usize, RawToolCall>,
586}
587
588impl MistralEventMapper {
589 pub fn new() -> Self {
590 Self {
591 tool_calls_by_index: HashMap::default(),
592 }
593 }
594
595 pub fn map_stream(
596 mut self,
597 events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
598 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
599 {
600 events.flat_map(move |event| {
601 futures::stream::iter(match event {
602 Ok(event) => self.map_event(event),
603 Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
604 })
605 })
606 }
607
608 pub fn map_event(
609 &mut self,
610 event: mistral::StreamResponse,
611 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
612 let Some(choice) = event.choices.first() else {
613 return vec![Err(LanguageModelCompletionError::from(anyhow!(
614 "Response contained no choices"
615 )))];
616 };
617
618 let mut events = Vec::new();
619 if let Some(content) = choice.delta.content.as_ref() {
620 match content {
621 mistral::MessageContentDelta::Text(text) => {
622 events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
623 }
624 mistral::MessageContentDelta::Parts(parts) => {
625 for part in parts {
626 match part {
627 mistral::MessagePart::Text { text } => {
628 events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
629 }
630 mistral::MessagePart::Thinking { thinking } => {
631 for tp in thinking.iter().cloned() {
632 match tp {
633 mistral::ThinkingPart::Text { text } => {
634 events.push(Ok(
635 LanguageModelCompletionEvent::Thinking {
636 text,
637 signature: None,
638 },
639 ));
640 }
641 }
642 }
643 }
644 mistral::MessagePart::ImageUrl { .. } => {
645 // We currently don't emit a separate event for images in responses.
646 }
647 }
648 }
649 }
650 }
651 }
652
653 if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
654 for tool_call in tool_calls {
655 let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
656
657 if let Some(tool_id) = tool_call.id.clone() {
658 entry.id = tool_id;
659 }
660
661 if let Some(function) = tool_call.function.as_ref() {
662 if let Some(name) = function.name.clone() {
663 entry.name = name;
664 }
665
666 if let Some(arguments) = function.arguments.clone() {
667 entry.arguments.push_str(&arguments);
668 }
669 }
670 }
671 }
672
673 if let Some(usage) = event.usage {
674 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
675 input_tokens: usage.prompt_tokens,
676 output_tokens: usage.completion_tokens,
677 cache_creation_input_tokens: 0,
678 cache_read_input_tokens: 0,
679 })));
680 }
681
682 if let Some(finish_reason) = choice.finish_reason.as_deref() {
683 match finish_reason {
684 "stop" => {
685 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
686 }
687 "tool_calls" => {
688 events.extend(self.process_tool_calls());
689 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
690 }
691 unexpected => {
692 log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
693 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
694 }
695 }
696 }
697
698 events
699 }
700
701 fn process_tool_calls(
702 &mut self,
703 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
704 let mut results = Vec::new();
705
706 for (_, tool_call) in self.tool_calls_by_index.drain() {
707 if tool_call.id.is_empty() || tool_call.name.is_empty() {
708 results.push(Err(LanguageModelCompletionError::from(anyhow!(
709 "Received incomplete tool call: missing id or name"
710 ))));
711 continue;
712 }
713
714 match serde_json::Value::from_str(&tool_call.arguments) {
715 Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
716 LanguageModelToolUse {
717 id: tool_call.id.into(),
718 name: tool_call.name.into(),
719 is_input_complete: true,
720 input,
721 raw_input: tool_call.arguments,
722 },
723 ))),
724 Err(error) => {
725 results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
726 id: tool_call.id.into(),
727 tool_name: tool_call.name.into(),
728 raw_input: tool_call.arguments.into(),
729 json_parse_error: error.to_string(),
730 }))
731 }
732 }
733 }
734
735 results
736 }
737}
738
/// Accumulator for one streamed tool call whose pieces arrive across several chunks.
#[derive(Default)]
struct RawToolCall {
    /// Tool-call id; stays empty until a chunk supplies it.
    id: String,
    /// Function name; stays empty until a chunk supplies it.
    name: String,
    /// Concatenation of the argument fragments received so far.
    arguments: String,
}
745
/// Settings view for entering or resetting the Mistral and Codestral API keys.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    codestral_api_key_editor: Entity<InputField>,
    state: Entity<State>,
    /// `Some` while the initial credential load runs; the view shows a loading label meanwhile.
    load_credentials_task: Option<Task<()>>,
}
752
impl ConfigurationView {
    /// Builds the view and starts loading both the Mistral and the Codestral keys.
    ///
    /// The view re-renders whenever `state` notifies, and again once loading finishes.
    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor =
            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
        let codestral_api_key_editor =
            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        // Keys are loaded sequentially: Mistral first, then Codestral. Failures are
        // ignored; the task only clears itself so the editors can be shown.
        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate_codestral(cx))
                    .log_err()
                {
                    let _ = task.await;
                }

                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            codestral_api_key_editor,
            state,
            load_credentials_task,
        }
    }

    /// On confirm, stores the trimmed Mistral key from the editor; empty input is ignored.
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// Clears the editor and the stored Mistral key.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// On confirm, stores the trimmed Codestral key; empty input is ignored.
    ///
    /// After the key is stored, Codestral is selected as the edit prediction provider.
    fn save_codestral_api_key(
        &mut self,
        _: &menu::Confirm,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        let api_key = self
            .codestral_api_key_editor
            .read(cx)
            .text(cx)
            .trim()
            .to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.codestral_api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| {
                    state.set_codestral_api_key(Some(api_key), cx)
                })?
                .await?;
            cx.update(|_window, cx| {
                set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
            })
        })
        .detach_and_log_err(cx);
    }

    /// Clears the Codestral key and switches the edit prediction provider to Zed.
    ///
    /// NOTE(review): the switch happens regardless of which provider was selected before.
    fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.codestral_api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
                .await?;
            cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
        })
        .detach_and_log_err(cx);
    }

    /// The Mistral key editor is shown only while no chat API key is available.
    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }

    /// Renders either the Codestral key entry form (when no key is present) or a
    /// "configured" row with a reset button.
    ///
    /// Reset is disabled when the key comes from `CODESTRAL_API_KEY`.
    fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
        let key_state = &self.state.read(cx).codestral_api_key_state;
        let should_show_editor = !key_state.has_key();
        let env_var_set = key_state.is_from_env_var();
        if should_show_editor {
            v_flex()
                .id("codestral")
                .size_full()
                .mt_2()
                .on_action(cx.listener(Self::save_codestral_api_key))
                .child(Label::new(
                    "To use Codestral as an edit prediction provider, \
                    you need to add a Codestral-specific API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("the Codestral section of Mistral's console"),
                            Some("https://console.mistral.ai/codestral"),
                        ))
                        .child(InstructionListItem::text_only("Paste your API key below and hit enter")),
                )
                .child(self.codestral_api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                ).into_any()
        } else {
            h_flex()
                .id("codestral")
                .mt_2()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
                        } else {
                            "Codestral API key configured".to_string()
                        })),
                )
                .child(
                    Button::new("reset-key", "Reset Key")
                        .label_size(LabelSize::Small)
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!(
                                "To reset your API key, \
                                unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
                            )))
                        })
                        .on_click(
                            cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
                        ),
                ).into_any()
        }
    }
}
952
impl Render for ConfigurationView {
    /// Renders one of three states, in priority order:
    /// 1. a loading label while credentials are still being loaded;
    /// 2. the Mistral key entry form when no key is available;
    /// 3. a "configured" row with a reset button.
    ///
    /// States 2 and 3 also include the Codestral section.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_api_key_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("Mistral's console"),
                            Some("https://console.mistral.ai/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your Mistral account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .child(self.render_codestral_api_key_editor(cx))
                .into_any()
        } else {
            v_flex()
                .size_full()
                .child(
                    h_flex()
                        .mt_1()
                        .p_1()
                        .justify_between()
                        .rounded_md()
                        .border_1()
                        .border_color(cx.theme().colors().border)
                        .bg(cx.theme().colors().background)
                        .child(
                            h_flex()
                                .gap_1()
                                .child(Icon::new(IconName::Check).color(Color::Success))
                                // The label names the source of the key. A non-default
                                // API URL is shown, truncated to 32 characters.
                                .child(Label::new(if env_var_set {
                                    format!(
                                        "API key set in {API_KEY_ENV_VAR_NAME} environment variable"
                                    )
                                } else {
                                    let api_url = MistralLanguageModelProvider::api_url(cx);
                                    if api_url == MISTRAL_API_URL {
                                        "API key configured".to_string()
                                    } else {
                                        format!(
                                            "API key configured for {}",
                                            truncate_and_trailoff(&api_url, 32)
                                        )
                                    }
                                })),
                        )
                        .child(
                            // Reset is disabled for keys supplied through the environment.
                            Button::new("reset-key", "Reset Key")
                                .label_size(LabelSize::Small)
                                .icon(Some(IconName::Trash))
                                .icon_size(IconSize::Small)
                                .icon_position(IconPosition::Start)
                                .disabled(env_var_set)
                                .when(env_var_set, |this| {
                                    this.tooltip(Tooltip::text(format!(
                                        "To reset your API key, \
                                        unset the {API_KEY_ENV_VAR_NAME} environment variable."
                                    )))
                                })
                                .on_click(cx.listener(|this, _, window, cx| {
                                    this.reset_api_key(window, cx)
                                })),
                        ),
                )
                .child(self.render_codestral_api_key_editor(cx))
                .into_any()
        }
    }
}
1042
1043fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
1044 let fs = <dyn Fs>::global(cx);
1045 update_settings_file(fs, cx, move |settings, _| {
1046 settings
1047 .project
1048 .all_languages
1049 .features
1050 .get_or_insert_default()
1051 .edit_prediction_provider = Some(provider);
1052 });
1053}
1054
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// A system + user text conversation maps to two Mistral messages and keeps
    /// the model id, the temperature and the streaming flag.
    #[test]
    fn test_into_mistral_basic_conversion() {
        let request = LanguageModelRequest {
            messages: vec![
                LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text("System prompt".into())],
                    cache: false,
                },
                LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![MessageContent::Text("Hello".into())],
                    cache: false,
                },
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
    }

    /// For an image-capable model, the text and image parts of one user message end
    /// up, in order, in a single multipart user message, with the image as a
    /// `data:image/png;base64,` URL.
    #[test]
    fn test_into_mistral_with_image() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: Default::default(),
                    }),
                ],
                cache: false,
            }],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        assert!(matches!(
            &mistral_request.messages[0],
            mistral::RequestMessage::User {
                content: mistral::MessageContent::Multipart { .. }
            }
        ));

        if let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &mistral_request.messages[0]
        {
            assert_eq!(content.len(), 2);
            assert!(matches!(
                &content[0],
                mistral::MessagePart::Text { text } if text == "What's in this image?"
            ));
            assert!(matches!(
                &content[1],
                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
            ));
        }
    }
}