mistral.rs

   1use anyhow::{Result, anyhow};
   2use collections::BTreeMap;
   3use fs::Fs;
   4use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
   5use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
   6use http_client::HttpClient;
   7use language_model::{
   8    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
   9    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  10    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  11    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  12    RateLimiter, Role, StopReason, TokenUsage,
  13};
  14use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
  15pub use settings::MistralAvailableModel as AvailableModel;
  16use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
  17use std::collections::HashMap;
  18use std::pin::Pin;
  19use std::str::FromStr;
  20use std::sync::{Arc, LazyLock};
  21use strum::IntoEnumIterator;
  22use ui::{List, prelude::*};
  23use ui_input::InputField;
  24use util::ResultExt;
  25use zed_env_vars::{EnvVar, env_var};
  26
  27use crate::ui::ConfiguredApiCard;
  28use crate::{api_key::ApiKeyState, ui::InstructionListItem};
  29
/// Identifier this provider reports from `LanguageModelProvider::id`.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Display name this provider reports from `LanguageModelProvider::name`; also attached to
/// the `NoApiKey` completion error.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Name of the environment variable handed to the API-key state for the main Mistral key.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

/// Name of the environment variable handed to the API-key state for the Codestral key.
const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
  38
/// Settings for the Mistral provider, read in this file through
/// `AllLanguageModelSettings::get_global(cx).mistral`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// Base URL of the Mistral API. When empty, `mistral::MISTRAL_API_URL` is used instead.
    pub api_url: String,
    /// Additional models from settings; each one is offered as a `mistral::Model::Custom`
    /// and replaces any built-in entry stored under the same key.
    pub available_models: Vec<AvailableModel>,
}
  44
/// Language-model provider for Mistral. A single instance is stored as a gpui global
/// (see `MistralLanguageModelProvider::global`).
pub struct MistralLanguageModelProvider {
    /// HTTP client cloned into every `MistralLanguageModel` this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Credential state shared with the models and the configuration view.
    state: Entity<State>,
}
  49
/// Credential state for the provider: one key for the main Mistral API and one for Codestral.
pub struct State {
    /// Key for the main API, stored and loaded against the configured API URL.
    api_key_state: ApiKeyState,
    /// Key for Codestral, stored and loaded against `CODESTRAL_API_URL`.
    codestral_api_key_state: ApiKeyState,
}
  54
  55impl State {
  56    fn is_authenticated(&self) -> bool {
  57        self.api_key_state.has_key()
  58    }
  59
  60    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  61        let api_url = MistralLanguageModelProvider::api_url(cx);
  62        self.api_key_state
  63            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
  64    }
  65
  66    fn set_codestral_api_key(
  67        &mut self,
  68        api_key: Option<String>,
  69        cx: &mut Context<Self>,
  70    ) -> Task<Result<()>> {
  71        self.codestral_api_key_state.store(
  72            CODESTRAL_API_URL.into(),
  73            api_key,
  74            |this| &mut this.codestral_api_key_state,
  75            cx,
  76        )
  77    }
  78
  79    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  80        let api_url = MistralLanguageModelProvider::api_url(cx);
  81        self.api_key_state.load_if_needed(
  82            api_url,
  83            &API_KEY_ENV_VAR,
  84            |this| &mut this.api_key_state,
  85            cx,
  86        )
  87    }
  88
  89    fn authenticate_codestral(
  90        &mut self,
  91        cx: &mut Context<Self>,
  92    ) -> Task<Result<(), AuthenticateError>> {
  93        self.codestral_api_key_state.load_if_needed(
  94            CODESTRAL_API_URL.into(),
  95            &CODESTRAL_API_KEY_ENV_VAR,
  96            |this| &mut this.codestral_api_key_state,
  97            cx,
  98        )
  99    }
 100}
 101
/// Newtype that lets the shared provider instance be stored as a gpui global.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
 105
 106impl MistralLanguageModelProvider {
 107    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
 108        cx.try_global::<GlobalMistralLanguageModelProvider>()
 109            .map(|this| &this.0)
 110    }
 111
 112    pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
 113        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
 114            return this.0.clone();
 115        }
 116        let state = cx.new(|cx| {
 117            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 118                let api_url = Self::api_url(cx);
 119                this.api_key_state.handle_url_change(
 120                    api_url,
 121                    &API_KEY_ENV_VAR,
 122                    |this| &mut this.api_key_state,
 123                    cx,
 124                );
 125                cx.notify();
 126            })
 127            .detach();
 128            State {
 129                api_key_state: ApiKeyState::new(Self::api_url(cx)),
 130                codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
 131            }
 132        });
 133
 134        let this = Arc::new(Self { http_client, state });
 135        cx.set_global(GlobalMistralLanguageModelProvider(this));
 136        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
 137    }
 138
 139    pub fn load_codestral_api_key(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 140        self.state
 141            .update(cx, |state, cx| state.authenticate_codestral(cx))
 142    }
 143
 144    pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
 145        self.state.read(cx).codestral_api_key_state.key(url)
 146    }
 147
 148    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
 149        Arc::new(MistralLanguageModel {
 150            id: LanguageModelId::from(model.id().to_string()),
 151            model,
 152            state: self.state.clone(),
 153            http_client: self.http_client.clone(),
 154            request_limiter: RateLimiter::new(4),
 155        })
 156    }
 157
 158    fn settings(cx: &App) -> &MistralSettings {
 159        &crate::AllLanguageModelSettings::get_global(cx).mistral
 160    }
 161
 162    fn api_url(cx: &App) -> SharedString {
 163        let api_url = &Self::settings(cx).api_url;
 164        if api_url.is_empty() {
 165            mistral::MISTRAL_API_URL.into()
 166        } else {
 167            SharedString::new(api_url.as_str())
 168        }
 169    }
 170}
 171
 172impl LanguageModelProviderState for MistralLanguageModelProvider {
 173    type ObservableEntity = State;
 174
 175    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 176        Some(self.state.clone())
 177    }
 178}
 179
 180impl LanguageModelProvider for MistralLanguageModelProvider {
 181    fn id(&self) -> LanguageModelProviderId {
 182        PROVIDER_ID
 183    }
 184
 185    fn name(&self) -> LanguageModelProviderName {
 186        PROVIDER_NAME
 187    }
 188
 189    fn icon(&self) -> IconName {
 190        IconName::AiMistral
 191    }
 192
 193    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 194        Some(self.create_language_model(mistral::Model::default()))
 195    }
 196
 197    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 198        Some(self.create_language_model(mistral::Model::default_fast()))
 199    }
 200
 201    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 202        let mut models = BTreeMap::default();
 203
 204        // Add base models from mistral::Model::iter()
 205        for model in mistral::Model::iter() {
 206            if !matches!(model, mistral::Model::Custom { .. }) {
 207                models.insert(model.id().to_string(), model);
 208            }
 209        }
 210
 211        // Override with available models from settings
 212        for model in &Self::settings(cx).available_models {
 213            models.insert(
 214                model.name.clone(),
 215                mistral::Model::Custom {
 216                    name: model.name.clone(),
 217                    display_name: model.display_name.clone(),
 218                    max_tokens: model.max_tokens,
 219                    max_output_tokens: model.max_output_tokens,
 220                    max_completion_tokens: model.max_completion_tokens,
 221                    supports_tools: model.supports_tools,
 222                    supports_images: model.supports_images,
 223                    supports_thinking: model.supports_thinking,
 224                },
 225            );
 226        }
 227
 228        models
 229            .into_values()
 230            .map(|model| {
 231                Arc::new(MistralLanguageModel {
 232                    id: LanguageModelId::from(model.id().to_string()),
 233                    model,
 234                    state: self.state.clone(),
 235                    http_client: self.http_client.clone(),
 236                    request_limiter: RateLimiter::new(4),
 237                }) as Arc<dyn LanguageModel>
 238            })
 239            .collect()
 240    }
 241
 242    fn is_authenticated(&self, cx: &App) -> bool {
 243        self.state.read(cx).is_authenticated()
 244    }
 245
 246    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 247        self.state.update(cx, |state, cx| state.authenticate(cx))
 248    }
 249
 250    fn configuration_view(
 251        &self,
 252        _target_agent: language_model::ConfigurationViewTargetAgent,
 253        window: &mut Window,
 254        cx: &mut App,
 255    ) -> AnyView {
 256        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 257            .into()
 258    }
 259
 260    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 261        self.state
 262            .update(cx, |state, cx| state.set_api_key(None, cx))
 263    }
 264}
 265
/// A single Mistral model exposed through the `LanguageModel` trait.
pub struct MistralLanguageModel {
    /// Identifier derived from `model.id()` when the model is created.
    id: LanguageModelId,
    /// The underlying Mistral model description.
    model: mistral::Model,
    /// Credential state shared with the provider; read to obtain the API key per request.
    state: Entity<State>,
    /// HTTP client used to issue completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Limiter that every streaming request is routed through.
    // NOTE(review): constructed with `RateLimiter::new(4)`; presumably a cap on concurrent
    // requests — confirm against `RateLimiter`.
    request_limiter: RateLimiter,
}
 273
 274impl MistralLanguageModel {
 275    fn stream_completion(
 276        &self,
 277        request: mistral::Request,
 278        cx: &AsyncApp,
 279    ) -> BoxFuture<
 280        'static,
 281        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
 282    > {
 283        let http_client = self.http_client.clone();
 284
 285        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
 286            let api_url = MistralLanguageModelProvider::api_url(cx);
 287            (state.api_key_state.key(&api_url), api_url)
 288        }) else {
 289            return future::ready(Err(anyhow!("App state dropped"))).boxed();
 290        };
 291
 292        let future = self.request_limiter.stream(async move {
 293            let Some(api_key) = api_key else {
 294                return Err(LanguageModelCompletionError::NoApiKey {
 295                    provider: PROVIDER_NAME,
 296                });
 297            };
 298            let request =
 299                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
 300            let response = request.await?;
 301            Ok(response)
 302        });
 303
 304        async move { Ok(future.await?.boxed()) }.boxed()
 305    }
 306}
 307
 308impl LanguageModel for MistralLanguageModel {
 309    fn id(&self) -> LanguageModelId {
 310        self.id.clone()
 311    }
 312
 313    fn name(&self) -> LanguageModelName {
 314        LanguageModelName::from(self.model.display_name().to_string())
 315    }
 316
 317    fn provider_id(&self) -> LanguageModelProviderId {
 318        PROVIDER_ID
 319    }
 320
 321    fn provider_name(&self) -> LanguageModelProviderName {
 322        PROVIDER_NAME
 323    }
 324
 325    fn supports_tools(&self) -> bool {
 326        self.model.supports_tools()
 327    }
 328
 329    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
 330        self.model.supports_tools()
 331    }
 332
 333    fn supports_images(&self) -> bool {
 334        self.model.supports_images()
 335    }
 336
 337    fn telemetry_id(&self) -> String {
 338        format!("mistral/{}", self.model.id())
 339    }
 340
 341    fn max_token_count(&self) -> u64 {
 342        self.model.max_token_count()
 343    }
 344
 345    fn max_output_tokens(&self) -> Option<u64> {
 346        self.model.max_output_tokens()
 347    }
 348
 349    fn count_tokens(
 350        &self,
 351        request: LanguageModelRequest,
 352        cx: &App,
 353    ) -> BoxFuture<'static, Result<u64>> {
 354        cx.background_spawn(async move {
 355            let messages = request
 356                .messages
 357                .into_iter()
 358                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 359                    role: match message.role {
 360                        Role::User => "user".into(),
 361                        Role::Assistant => "assistant".into(),
 362                        Role::System => "system".into(),
 363                    },
 364                    content: Some(message.string_contents()),
 365                    name: None,
 366                    function_call: None,
 367                })
 368                .collect::<Vec<_>>();
 369
 370            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 371        })
 372        .boxed()
 373    }
 374
 375    fn stream_completion(
 376        &self,
 377        request: LanguageModelRequest,
 378        cx: &AsyncApp,
 379    ) -> BoxFuture<
 380        'static,
 381        Result<
 382            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 383            LanguageModelCompletionError,
 384        >,
 385    > {
 386        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
 387        let stream = self.stream_completion(request, cx);
 388
 389        async move {
 390            let stream = stream.await?;
 391            let mapper = MistralEventMapper::new();
 392            Ok(mapper.map_stream(stream).boxed())
 393        }
 394        .boxed()
 395    }
 396}
 397
/// Converts a generic `LanguageModelRequest` into a streaming `mistral::Request` for `model`.
///
/// Message conversion, per role:
/// - **User**: text, image (only if the model supports images) and thinking (only if the
///   model supports thinking) parts are gathered into a single user message, which is
///   pushed only if it is not empty plain content. Each tool result becomes its own `Tool`
///   message, pushed as soon as it is encountered — i.e. before the gathered user message
///   of the same source message. Tool uses and redacted thinking are dropped.
/// - **Assistant**: each text part and each thinking part (if supported) becomes a separate
///   assistant message. A tool use is appended to the most recently pushed message when
///   that message is an assistant message; otherwise a new content-less assistant message
///   is created for it. Images, tool results and redacted thinking are dropped.
/// - **System**: each text part and each thinking part (if supported) becomes a separate
///   system message; everything else is dropped.
///
/// `max_output_tokens` is forwarded as the request's `max_tokens`.
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> mistral::Request {
    // Requests built here always ask for a streamed response.
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            // Pushed immediately, so it precedes the user message built from
                            // the remaining parts of this same message.
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                // Skip the user message when nothing was gathered (still empty plain content).
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        // A serialization failure yields an empty arguments string.
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            // Attach to the preceding assistant message when there is one;
                            // otherwise start a content-less assistant message.
                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    mistral::Request {
        model: model.id().to_string(),
        messages,
        stream,
        max_tokens: max_output_tokens,
        temperature: request.temperature,
        response_format: None,
        // `Auto`/`Any` are only sent when tools exist; an explicit `None` is always sent;
        // with tools but no usable choice the request falls back to `Auto`.
        tool_choice: match request.tool_choice {
            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Auto)
            }
            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                Some(mistral::ToolChoice::Any)
            }
            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
            _ => None,
        },
        // Parallel tool calls are explicitly disabled whenever tools are supplied.
        parallel_tool_calls: if !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| mistral::ToolDefinition::Function {
                function: mistral::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
    }
}
 584
/// Translates Mistral stream responses into `LanguageModelCompletionEvent`s, buffering
/// partial tool calls until the stream reports a `"tool_calls"` finish reason.
pub struct MistralEventMapper {
    /// Tool calls accumulated so far, keyed by the `index` carried on each tool-call delta.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
 588
 589impl MistralEventMapper {
 590    pub fn new() -> Self {
 591        Self {
 592            tool_calls_by_index: HashMap::default(),
 593        }
 594    }
 595
 596    pub fn map_stream(
 597        mut self,
 598        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
 599    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 600    {
 601        events.flat_map(move |event| {
 602            futures::stream::iter(match event {
 603                Ok(event) => self.map_event(event),
 604                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
 605            })
 606        })
 607    }
 608
 609    pub fn map_event(
 610        &mut self,
 611        event: mistral::StreamResponse,
 612    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 613        let Some(choice) = event.choices.first() else {
 614            return vec![Err(LanguageModelCompletionError::from(anyhow!(
 615                "Response contained no choices"
 616            )))];
 617        };
 618
 619        let mut events = Vec::new();
 620        if let Some(content) = choice.delta.content.as_ref() {
 621            match content {
 622                mistral::MessageContentDelta::Text(text) => {
 623                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 624                }
 625                mistral::MessageContentDelta::Parts(parts) => {
 626                    for part in parts {
 627                        match part {
 628                            mistral::MessagePart::Text { text } => {
 629                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 630                            }
 631                            mistral::MessagePart::Thinking { thinking } => {
 632                                for tp in thinking.iter().cloned() {
 633                                    match tp {
 634                                        mistral::ThinkingPart::Text { text } => {
 635                                            events.push(Ok(
 636                                                LanguageModelCompletionEvent::Thinking {
 637                                                    text,
 638                                                    signature: None,
 639                                                },
 640                                            ));
 641                                        }
 642                                    }
 643                                }
 644                            }
 645                            mistral::MessagePart::ImageUrl { .. } => {
 646                                // We currently don't emit a separate event for images in responses.
 647                            }
 648                        }
 649                    }
 650                }
 651            }
 652        }
 653
 654        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
 655            for tool_call in tool_calls {
 656                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
 657
 658                if let Some(tool_id) = tool_call.id.clone() {
 659                    entry.id = tool_id;
 660                }
 661
 662                if let Some(function) = tool_call.function.as_ref() {
 663                    if let Some(name) = function.name.clone() {
 664                        entry.name = name;
 665                    }
 666
 667                    if let Some(arguments) = function.arguments.clone() {
 668                        entry.arguments.push_str(&arguments);
 669                    }
 670                }
 671            }
 672        }
 673
 674        if let Some(usage) = event.usage {
 675            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 676                input_tokens: usage.prompt_tokens,
 677                output_tokens: usage.completion_tokens,
 678                cache_creation_input_tokens: 0,
 679                cache_read_input_tokens: 0,
 680            })));
 681        }
 682
 683        if let Some(finish_reason) = choice.finish_reason.as_deref() {
 684            match finish_reason {
 685                "stop" => {
 686                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 687                }
 688                "tool_calls" => {
 689                    events.extend(self.process_tool_calls());
 690                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 691                }
 692                unexpected => {
 693                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
 694                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 695                }
 696            }
 697        }
 698
 699        events
 700    }
 701
 702    fn process_tool_calls(
 703        &mut self,
 704    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 705        let mut results = Vec::new();
 706
 707        for (_, tool_call) in self.tool_calls_by_index.drain() {
 708            if tool_call.id.is_empty() || tool_call.name.is_empty() {
 709                results.push(Err(LanguageModelCompletionError::from(anyhow!(
 710                    "Received incomplete tool call: missing id or name"
 711                ))));
 712                continue;
 713            }
 714
 715            match serde_json::Value::from_str(&tool_call.arguments) {
 716                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
 717                    LanguageModelToolUse {
 718                        id: tool_call.id.into(),
 719                        name: tool_call.name.into(),
 720                        is_input_complete: true,
 721                        input,
 722                        raw_input: tool_call.arguments,
 723                    },
 724                ))),
 725                Err(error) => {
 726                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 727                        id: tool_call.id.into(),
 728                        tool_name: tool_call.name.into(),
 729                        raw_input: tool_call.arguments.into(),
 730                        json_parse_error: error.to_string(),
 731                    }))
 732                }
 733            }
 734        }
 735
 736        results
 737    }
 738}
 739
/// A tool call assembled incrementally from streamed deltas.
#[derive(Default)]
struct RawToolCall {
    /// Tool call id; stays empty until a delta supplies one.
    id: String,
    /// Function name; stays empty until a delta supplies one.
    name: String,
    /// Concatenation of all argument fragments received so far (parsed as JSON at the end).
    arguments: String,
}
 746
/// Settings view for the provider's API keys.
struct ConfigurationView {
    /// Input field for the main Mistral API key.
    api_key_editor: Entity<InputField>,
    /// Input field for the Codestral API key.
    codestral_api_key_editor: Entity<InputField>,
    /// Credential state shared with the provider.
    state: Entity<State>,
    /// `Some` while the initial credential load is running; reset to `None` when it finishes.
    load_credentials_task: Option<Task<()>>,
}
 753
 754impl ConfigurationView {
 755    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 756        let api_key_editor =
 757            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
 758        let codestral_api_key_editor =
 759            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
 760
 761        cx.observe(&state, |_, _, cx| {
 762            cx.notify();
 763        })
 764        .detach();
 765
 766        let load_credentials_task = Some(cx.spawn_in(window, {
 767            let state = state.clone();
 768            async move |this, cx| {
 769                if let Some(task) = state
 770                    .update(cx, |state, cx| state.authenticate(cx))
 771                    .log_err()
 772                {
 773                    // We don't log an error, because "not signed in" is also an error.
 774                    let _ = task.await;
 775                }
 776                if let Some(task) = state
 777                    .update(cx, |state, cx| state.authenticate_codestral(cx))
 778                    .log_err()
 779                {
 780                    let _ = task.await;
 781                }
 782
 783                this.update(cx, |this, cx| {
 784                    this.load_credentials_task = None;
 785                    cx.notify();
 786                })
 787                .log_err();
 788            }
 789        }));
 790
 791        Self {
 792            api_key_editor,
 793            codestral_api_key_editor,
 794            state,
 795            load_credentials_task,
 796        }
 797    }
 798
 799    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 800        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 801        if api_key.is_empty() {
 802            return;
 803        }
 804
 805        // url changes can cause the editor to be displayed again
 806        self.api_key_editor
 807            .update(cx, |editor, cx| editor.set_text("", window, cx));
 808
 809        let state = self.state.clone();
 810        cx.spawn_in(window, async move |_, cx| {
 811            state
 812                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
 813                .await
 814        })
 815        .detach_and_log_err(cx);
 816    }
 817
 818    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 819        self.api_key_editor
 820            .update(cx, |editor, cx| editor.set_text("", window, cx));
 821
 822        let state = self.state.clone();
 823        cx.spawn_in(window, async move |_, cx| {
 824            state
 825                .update(cx, |state, cx| state.set_api_key(None, cx))?
 826                .await
 827        })
 828        .detach_and_log_err(cx);
 829    }
 830
 831    fn save_codestral_api_key(
 832        &mut self,
 833        _: &menu::Confirm,
 834        window: &mut Window,
 835        cx: &mut Context<Self>,
 836    ) {
 837        let api_key = self
 838            .codestral_api_key_editor
 839            .read(cx)
 840            .text(cx)
 841            .trim()
 842            .to_string();
 843        if api_key.is_empty() {
 844            return;
 845        }
 846
 847        // url changes can cause the editor to be displayed again
 848        self.codestral_api_key_editor
 849            .update(cx, |editor, cx| editor.set_text("", window, cx));
 850
 851        let state = self.state.clone();
 852        cx.spawn_in(window, async move |_, cx| {
 853            state
 854                .update(cx, |state, cx| {
 855                    state.set_codestral_api_key(Some(api_key), cx)
 856                })?
 857                .await?;
 858            cx.update(|_window, cx| {
 859                set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
 860            })
 861        })
 862        .detach_and_log_err(cx);
 863    }
 864
 865    fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 866        self.codestral_api_key_editor
 867            .update(cx, |editor, cx| editor.set_text("", window, cx));
 868
 869        let state = self.state.clone();
 870        cx.spawn_in(window, async move |_, cx| {
 871            state
 872                .update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
 873                .await?;
 874            cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
 875        })
 876        .detach_and_log_err(cx);
 877    }
 878
 879    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
 880        !self.state.read(cx).is_authenticated()
 881    }
 882
 883    fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
 884        let key_state = &self.state.read(cx).codestral_api_key_state;
 885        let should_show_editor = !key_state.has_key();
 886        let env_var_set = key_state.is_from_env_var();
 887        let configured_card_label = if env_var_set {
 888            format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
 889        } else {
 890            "Codestral API key configured".to_string()
 891        };
 892
 893        if should_show_editor {
 894            v_flex()
 895                .id("codestral")
 896                .size_full()
 897                .mt_2()
 898                .on_action(cx.listener(Self::save_codestral_api_key))
 899                .child(Label::new(
 900                    "To use Codestral as an edit prediction provider, \
 901                    you need to add a Codestral-specific API key. Follow these steps:",
 902                ))
 903                .child(
 904                    List::new()
 905                        .child(InstructionListItem::new(
 906                            "Create one by visiting",
 907                            Some("the Codestral section of Mistral's console"),
 908                            Some("https://console.mistral.ai/codestral"),
 909                        ))
 910                        .child(InstructionListItem::text_only("Paste your API key below and hit enter")),
 911                )
 912                .child(self.codestral_api_key_editor.clone())
 913                .child(
 914                    Label::new(
 915                        format!("You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
 916                    )
 917                    .size(LabelSize::Small).color(Color::Muted),
 918                ).into_any()
 919        } else {
 920            ConfiguredApiCard::new(configured_card_label)
 921                .disabled(env_var_set)
 922                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 923                .when(env_var_set, |this| {
 924                    this.tooltip_label(format!(
 925                        "To reset your API key, \
 926                            unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
 927                    ))
 928                })
 929                .on_click(
 930                    cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
 931                )
 932                .into_any_element()
 933        }
 934    }
 935}
 936
 937impl Render for ConfigurationView {
 938    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 939        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 940        let configured_card_label = if env_var_set {
 941            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
 942        } else {
 943            let api_url = MistralLanguageModelProvider::api_url(cx);
 944            if api_url == MISTRAL_API_URL {
 945                "API key configured".to_string()
 946            } else {
 947                format!("API key configured for {}", api_url)
 948            }
 949        };
 950
 951        if self.load_credentials_task.is_some() {
 952            div().child(Label::new("Loading credentials...")).into_any()
 953        } else if self.should_render_api_key_editor(cx) {
 954            v_flex()
 955                .size_full()
 956                .on_action(cx.listener(Self::save_api_key))
 957                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
 958                .child(
 959                    List::new()
 960                        .child(InstructionListItem::new(
 961                            "Create one by visiting",
 962                            Some("Mistral's console"),
 963                            Some("https://console.mistral.ai/api-keys"),
 964                        ))
 965                        .child(InstructionListItem::text_only(
 966                            "Ensure your Mistral account has credits",
 967                        ))
 968                        .child(InstructionListItem::text_only(
 969                            "Paste your API key below and hit enter to start using the assistant",
 970                        )),
 971                )
 972                .child(self.api_key_editor.clone())
 973                .child(
 974                    Label::new(
 975                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
 976                    )
 977                    .size(LabelSize::Small).color(Color::Muted),
 978                )
 979                .child(self.render_codestral_api_key_editor(cx))
 980                .into_any()
 981        } else {
 982            v_flex()
 983                .size_full()
 984                .gap_1()
 985                .child(
 986                    ConfiguredApiCard::new(configured_card_label)
 987                        .disabled(env_var_set)
 988                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 989                        .when(env_var_set, |this| {
 990                            this.tooltip_label(format!(
 991                                "To reset your API key, \
 992                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
 993                            ))
 994                        }),
 995                )
 996                .child(self.render_codestral_api_key_editor(cx))
 997                .into_any()
 998        }
 999    }
1000}
1001
1002fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
1003    let fs = <dyn Fs>::global(cx);
1004    update_settings_file(fs, cx, move |settings, _| {
1005        settings
1006            .project
1007            .all_languages
1008            .features
1009            .get_or_insert_default()
1010            .edit_prediction_provider = Some(provider);
1011    });
1012}
1013
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// Builds a request that carries only `messages` and `temperature`; every other field
    /// is left empty/unset, with thinking allowed.
    fn request_with(
        messages: Vec<LanguageModelRequestMessage>,
        temperature: Option<f32>,
    ) -> LanguageModelRequest {
        LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages,
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature,
            thinking_allowed: true,
        }
    }

    /// A system + user text conversation keeps its model id, temperature, message count and
    /// streaming flag.
    #[test]
    fn test_into_mistral_basic_conversion() {
        let system_message = LanguageModelRequestMessage {
            role: Role::System,
            content: vec![MessageContent::Text("System prompt".into())],
            cache: false,
        };
        let user_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text("Hello".into())],
            cache: false,
        };
        let request = request_with(vec![system_message, user_message], Some(0.5));

        let converted = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(converted.model, "mistral-small-latest");
        assert_eq!(converted.temperature, Some(0.5));
        assert_eq!(converted.messages.len(), 2);
        assert!(converted.stream);
    }

    /// A user message with text and an image becomes a two-part multipart message whose
    /// image part is a PNG data URL.
    #[test]
    fn test_into_mistral_with_image() {
        let user_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![
                MessageContent::Text("What's in this image?".into()),
                MessageContent::Image(LanguageModelImage {
                    source: "base64data".into(),
                    size: Default::default(),
                }),
            ],
            cache: false,
        };
        let request = request_with(vec![user_message], None);

        let converted = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(converted.messages.len(), 1);

        // Fail loudly if the message is anything other than a multipart user message.
        let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content: parts },
        } = &converted.messages[0]
        else {
            panic!("expected a multipart user message");
        };

        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            mistral::MessagePart::Text { text } if text == "What's in this image?"
        ));
        assert!(matches!(
            &parts[1],
            mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
        ));
    }
}