mistral.rs

   1use anyhow::{Result, anyhow};
   2use collections::BTreeMap;
   3use credentials_provider::CredentialsProvider;
   4
   5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
   6use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
   7use http_client::HttpClient;
   8use language_model::{
   9    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  10    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  11    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  12    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
  13    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
  14};
  15pub use mistral::{MISTRAL_API_URL, StreamResponse};
  16pub use settings::MistralAvailableModel as AvailableModel;
  17use settings::{Settings, SettingsStore};
  18use std::collections::HashMap;
  19use std::pin::Pin;
  20use std::sync::{Arc, LazyLock};
  21use strum::IntoEnumIterator;
  22use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  23use ui_input::InputField;
  24use util::ResultExt;
  25
  26use language_model::util::{fix_streamed_json, parse_tool_arguments};
  27
  28const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
  29const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
  30
  31const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
  32static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
  33
  34#[derive(Default, Clone, Debug, PartialEq)]
  35pub struct MistralSettings {
  36    pub api_url: String,
  37    pub available_models: Vec<AvailableModel>,
  38}
  39
  40pub struct MistralLanguageModelProvider {
  41    http_client: Arc<dyn HttpClient>,
  42    pub state: Entity<State>,
  43}
  44
  45pub struct State {
  46    api_key_state: ApiKeyState,
  47    credentials_provider: Arc<dyn CredentialsProvider>,
  48}
  49
  50impl State {
  51    fn is_authenticated(&self) -> bool {
  52        self.api_key_state.has_key()
  53    }
  54
  55    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  56        let credentials_provider = self.credentials_provider.clone();
  57        let api_url = MistralLanguageModelProvider::api_url(cx);
  58        self.api_key_state.store(
  59            api_url,
  60            api_key,
  61            |this| &mut this.api_key_state,
  62            credentials_provider,
  63            cx,
  64        )
  65    }
  66
  67    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  68        let credentials_provider = self.credentials_provider.clone();
  69        let api_url = MistralLanguageModelProvider::api_url(cx);
  70        self.api_key_state.load_if_needed(
  71            api_url,
  72            |this| &mut this.api_key_state,
  73            credentials_provider,
  74            cx,
  75        )
  76    }
  77}
  78
  79struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);
  80
  81impl Global for GlobalMistralLanguageModelProvider {}
  82
  83impl MistralLanguageModelProvider {
  84    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
  85        cx.try_global::<GlobalMistralLanguageModelProvider>()
  86            .map(|this| &this.0)
  87    }
  88
  89    pub fn global(
  90        http_client: Arc<dyn HttpClient>,
  91        credentials_provider: Arc<dyn CredentialsProvider>,
  92        cx: &mut App,
  93    ) -> Arc<Self> {
  94        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
  95            return this.0.clone();
  96        }
  97        let state = cx.new(|cx| {
  98            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
  99                let credentials_provider = this.credentials_provider.clone();
 100                let api_url = Self::api_url(cx);
 101                this.api_key_state.handle_url_change(
 102                    api_url,
 103                    |this| &mut this.api_key_state,
 104                    credentials_provider,
 105                    cx,
 106                );
 107                cx.notify();
 108            })
 109            .detach();
 110            State {
 111                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
 112                credentials_provider,
 113            }
 114        });
 115
 116        let this = Arc::new(Self { http_client, state });
 117        cx.set_global(GlobalMistralLanguageModelProvider(this));
 118        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
 119    }
 120
 121    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
 122        Arc::new(MistralLanguageModel {
 123            id: LanguageModelId::from(model.id().to_string()),
 124            model,
 125            state: self.state.clone(),
 126            http_client: self.http_client.clone(),
 127            request_limiter: RateLimiter::new(4),
 128        })
 129    }
 130
 131    fn settings(cx: &App) -> &MistralSettings {
 132        &crate::AllLanguageModelSettings::get_global(cx).mistral
 133    }
 134
 135    pub fn api_url(cx: &App) -> SharedString {
 136        let api_url = &Self::settings(cx).api_url;
 137        if api_url.is_empty() {
 138            mistral::MISTRAL_API_URL.into()
 139        } else {
 140            SharedString::new(api_url.as_str())
 141        }
 142    }
 143}
 144
 145impl LanguageModelProviderState for MistralLanguageModelProvider {
 146    type ObservableEntity = State;
 147
 148    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 149        Some(self.state.clone())
 150    }
 151}
 152
 153impl LanguageModelProvider for MistralLanguageModelProvider {
 154    fn id(&self) -> LanguageModelProviderId {
 155        PROVIDER_ID
 156    }
 157
 158    fn name(&self) -> LanguageModelProviderName {
 159        PROVIDER_NAME
 160    }
 161
 162    fn icon(&self) -> IconOrSvg {
 163        IconOrSvg::Icon(IconName::AiMistral)
 164    }
 165
 166    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 167        Some(self.create_language_model(mistral::Model::default()))
 168    }
 169
 170    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 171        Some(self.create_language_model(mistral::Model::default_fast()))
 172    }
 173
 174    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 175        let mut models = BTreeMap::default();
 176
 177        // Add base models from mistral::Model::iter()
 178        for model in mistral::Model::iter() {
 179            if !matches!(model, mistral::Model::Custom { .. }) {
 180                models.insert(model.id().to_string(), model);
 181            }
 182        }
 183
 184        // Override with available models from settings
 185        for model in &Self::settings(cx).available_models {
 186            models.insert(
 187                model.name.clone(),
 188                mistral::Model::Custom {
 189                    name: model.name.clone(),
 190                    display_name: model.display_name.clone(),
 191                    max_tokens: model.max_tokens,
 192                    max_output_tokens: model.max_output_tokens,
 193                    max_completion_tokens: model.max_completion_tokens,
 194                    supports_tools: model.supports_tools,
 195                    supports_images: model.supports_images,
 196                    supports_thinking: model.supports_thinking,
 197                },
 198            );
 199        }
 200
 201        models
 202            .into_values()
 203            .map(|model| {
 204                Arc::new(MistralLanguageModel {
 205                    id: LanguageModelId::from(model.id().to_string()),
 206                    model,
 207                    state: self.state.clone(),
 208                    http_client: self.http_client.clone(),
 209                    request_limiter: RateLimiter::new(4),
 210                }) as Arc<dyn LanguageModel>
 211            })
 212            .collect()
 213    }
 214
 215    fn is_authenticated(&self, cx: &App) -> bool {
 216        self.state.read(cx).is_authenticated()
 217    }
 218
 219    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 220        self.state.update(cx, |state, cx| state.authenticate(cx))
 221    }
 222
 223    fn configuration_view(
 224        &self,
 225        _target_agent: language_model::ConfigurationViewTargetAgent,
 226        window: &mut Window,
 227        cx: &mut App,
 228    ) -> AnyView {
 229        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 230            .into()
 231    }
 232
 233    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 234        self.state
 235            .update(cx, |state, cx| state.set_api_key(None, cx))
 236    }
 237}
 238
 239pub struct MistralLanguageModel {
 240    id: LanguageModelId,
 241    model: mistral::Model,
 242    state: Entity<State>,
 243    http_client: Arc<dyn HttpClient>,
 244    request_limiter: RateLimiter,
 245}
 246
 247impl MistralLanguageModel {
 248    fn stream_completion(
 249        &self,
 250        request: mistral::Request,
 251        affinity: Option<String>,
 252        cx: &AsyncApp,
 253    ) -> BoxFuture<
 254        'static,
 255        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
 256    > {
 257        let http_client = self.http_client.clone();
 258
 259        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
 260            let api_url = MistralLanguageModelProvider::api_url(cx);
 261            (state.api_key_state.key(&api_url), api_url)
 262        });
 263
 264        let future = self.request_limiter.stream(async move {
 265            let Some(api_key) = api_key else {
 266                return Err(LanguageModelCompletionError::NoApiKey {
 267                    provider: PROVIDER_NAME,
 268                });
 269            };
 270            let request = mistral::stream_completion(
 271                http_client.as_ref(),
 272                &api_url,
 273                &api_key,
 274                request,
 275                affinity,
 276            );
 277            let response = request.await?;
 278            Ok(response)
 279        });
 280
 281        async move { Ok(future.await?.boxed()) }.boxed()
 282    }
 283}
 284
 285impl LanguageModel for MistralLanguageModel {
 286    fn id(&self) -> LanguageModelId {
 287        self.id.clone()
 288    }
 289
 290    fn name(&self) -> LanguageModelName {
 291        LanguageModelName::from(self.model.display_name().to_string())
 292    }
 293
 294    fn provider_id(&self) -> LanguageModelProviderId {
 295        PROVIDER_ID
 296    }
 297
 298    fn provider_name(&self) -> LanguageModelProviderName {
 299        PROVIDER_NAME
 300    }
 301
 302    fn supports_tools(&self) -> bool {
 303        self.model.supports_tools()
 304    }
 305
 306    fn supports_streaming_tools(&self) -> bool {
 307        true
 308    }
 309
 310    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
 311        self.model.supports_tools()
 312    }
 313
 314    fn supports_images(&self) -> bool {
 315        self.model.supports_images()
 316    }
 317
 318    fn telemetry_id(&self) -> String {
 319        format!("mistral/{}", self.model.id())
 320    }
 321
 322    fn max_token_count(&self) -> u64 {
 323        self.model.max_token_count()
 324    }
 325
 326    fn max_output_tokens(&self) -> Option<u64> {
 327        self.model.max_output_tokens()
 328    }
 329
 330    fn count_tokens(
 331        &self,
 332        request: LanguageModelRequest,
 333        cx: &App,
 334    ) -> BoxFuture<'static, Result<u64>> {
 335        cx.background_spawn(async move {
 336            let messages = request
 337                .messages
 338                .into_iter()
 339                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 340                    role: match message.role {
 341                        Role::User => "user".into(),
 342                        Role::Assistant => "assistant".into(),
 343                        Role::System => "system".into(),
 344                    },
 345                    content: Some(message.string_contents()),
 346                    name: None,
 347                    function_call: None,
 348                })
 349                .collect::<Vec<_>>();
 350
 351            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 352        })
 353        .boxed()
 354    }
 355
 356    fn stream_completion(
 357        &self,
 358        request: LanguageModelRequest,
 359        cx: &AsyncApp,
 360    ) -> BoxFuture<
 361        'static,
 362        Result<
 363            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 364            LanguageModelCompletionError,
 365        >,
 366    > {
 367        let (request, affinity) =
 368            into_mistral(request, self.model.clone(), self.max_output_tokens());
 369        let stream = self.stream_completion(request, affinity, cx);
 370
 371        async move {
 372            let stream = stream.await?;
 373            let mapper = MistralEventMapper::new();
 374            Ok(mapper.map_stream(stream).boxed())
 375        }
 376        .boxed()
 377    }
 378}
 379
 380pub fn into_mistral(
 381    request: LanguageModelRequest,
 382    model: mistral::Model,
 383    max_output_tokens: Option<u64>,
 384) -> (mistral::Request, Option<String>) {
 385    let stream = true;
 386
 387    let mut messages = Vec::new();
 388    for message in &request.messages {
 389        match message.role {
 390            Role::User => {
 391                let mut message_content = mistral::MessageContent::empty();
 392                for content in &message.content {
 393                    match content {
 394                        MessageContent::Text(text) => {
 395                            message_content
 396                                .push_part(mistral::MessagePart::Text { text: text.clone() });
 397                        }
 398                        MessageContent::Image(image_content) => {
 399                            if model.supports_images() {
 400                                message_content.push_part(mistral::MessagePart::ImageUrl {
 401                                    image_url: image_content.to_base64_url(),
 402                                });
 403                            }
 404                        }
 405                        MessageContent::Thinking { text, .. } => {
 406                            if model.supports_thinking() {
 407                                message_content.push_part(mistral::MessagePart::Thinking {
 408                                    thinking: vec![mistral::ThinkingPart::Text {
 409                                        text: text.clone(),
 410                                    }],
 411                                });
 412                            }
 413                        }
 414                        MessageContent::RedactedThinking(_) => {}
 415                        MessageContent::ToolUse(_) => {
 416                            // Tool use is not supported in User messages for Mistral
 417                        }
 418                        MessageContent::ToolResult(tool_result) => {
 419                            let tool_content = match &tool_result.content {
 420                                LanguageModelToolResultContent::Text(text) => text.to_string(),
 421                                LanguageModelToolResultContent::Image(_) => {
 422                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
 423                                }
 424                            };
 425                            messages.push(mistral::RequestMessage::Tool {
 426                                content: tool_content,
 427                                tool_call_id: tool_result.tool_use_id.to_string(),
 428                            });
 429                        }
 430                    }
 431                }
 432                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
 433                {
 434                    messages.push(mistral::RequestMessage::User {
 435                        content: message_content,
 436                    });
 437                }
 438            }
 439            Role::Assistant => {
 440                for content in &message.content {
 441                    match content {
 442                        MessageContent::Text(text) if text.is_empty() => {
 443                            // Mistral API returns a 400 if there's neither content nor tool_calls
 444                        }
 445                        MessageContent::Text(text) => {
 446                            messages.push(mistral::RequestMessage::Assistant {
 447                                content: Some(mistral::MessageContent::Plain {
 448                                    content: text.clone(),
 449                                }),
 450                                tool_calls: Vec::new(),
 451                            });
 452                        }
 453                        MessageContent::Thinking { text, .. } => {
 454                            if model.supports_thinking() {
 455                                messages.push(mistral::RequestMessage::Assistant {
 456                                    content: Some(mistral::MessageContent::Multipart {
 457                                        content: vec![mistral::MessagePart::Thinking {
 458                                            thinking: vec![mistral::ThinkingPart::Text {
 459                                                text: text.clone(),
 460                                            }],
 461                                        }],
 462                                    }),
 463                                    tool_calls: Vec::new(),
 464                                });
 465                            }
 466                        }
 467                        MessageContent::RedactedThinking(_) => {}
 468                        MessageContent::Image(_) => {}
 469                        MessageContent::ToolUse(tool_use) => {
 470                            let tool_call = mistral::ToolCall {
 471                                id: tool_use.id.to_string(),
 472                                content: mistral::ToolCallContent::Function {
 473                                    function: mistral::FunctionContent {
 474                                        name: tool_use.name.to_string(),
 475                                        arguments: serde_json::to_string(&tool_use.input)
 476                                            .unwrap_or_default(),
 477                                    },
 478                                },
 479                            };
 480
 481                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
 482                                messages.last_mut()
 483                            {
 484                                tool_calls.push(tool_call);
 485                            } else {
 486                                messages.push(mistral::RequestMessage::Assistant {
 487                                    content: None,
 488                                    tool_calls: vec![tool_call],
 489                                });
 490                            }
 491                        }
 492                        MessageContent::ToolResult(_) => {
 493                            // Tool results are not supported in Assistant messages
 494                        }
 495                    }
 496                }
 497            }
 498            Role::System => {
 499                for content in &message.content {
 500                    match content {
 501                        MessageContent::Text(text) => {
 502                            messages.push(mistral::RequestMessage::System {
 503                                content: mistral::MessageContent::Plain {
 504                                    content: text.clone(),
 505                                },
 506                            });
 507                        }
 508                        MessageContent::Thinking { text, .. } => {
 509                            if model.supports_thinking() {
 510                                messages.push(mistral::RequestMessage::System {
 511                                    content: mistral::MessageContent::Multipart {
 512                                        content: vec![mistral::MessagePart::Thinking {
 513                                            thinking: vec![mistral::ThinkingPart::Text {
 514                                                text: text.clone(),
 515                                            }],
 516                                        }],
 517                                    },
 518                                });
 519                            }
 520                        }
 521                        MessageContent::RedactedThinking(_) => {}
 522                        MessageContent::Image(_)
 523                        | MessageContent::ToolUse(_)
 524                        | MessageContent::ToolResult(_) => {
 525                            // Images and tools are not supported in System messages
 526                        }
 527                    }
 528                }
 529            }
 530        }
 531    }
 532
 533    (
 534        mistral::Request {
 535            model: model.id().to_string(),
 536            messages,
 537            stream,
 538            stream_options: if stream {
 539                Some(mistral::StreamOptions {
 540                    stream_tool_calls: Some(true),
 541                })
 542            } else {
 543                None
 544            },
 545            max_tokens: max_output_tokens,
 546            temperature: request.temperature,
 547            response_format: None,
 548            tool_choice: match request.tool_choice {
 549                Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
 550                    Some(mistral::ToolChoice::Auto)
 551                }
 552                Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
 553                    Some(mistral::ToolChoice::Any)
 554                }
 555                Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
 556                _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
 557                _ => None,
 558            },
 559            parallel_tool_calls: if !request.tools.is_empty() {
 560                Some(false)
 561            } else {
 562                None
 563            },
 564            tools: request
 565                .tools
 566                .into_iter()
 567                .map(|tool| mistral::ToolDefinition::Function {
 568                    function: mistral::FunctionDefinition {
 569                        name: tool.name,
 570                        description: Some(tool.description),
 571                        parameters: Some(tool.input_schema),
 572                    },
 573                })
 574                .collect(),
 575        },
 576        request.thread_id,
 577    )
 578}
 579
 580pub struct MistralEventMapper {
 581    tool_calls_by_index: HashMap<usize, RawToolCall>,
 582}
 583
 584impl MistralEventMapper {
 585    pub fn new() -> Self {
 586        Self {
 587            tool_calls_by_index: HashMap::default(),
 588        }
 589    }
 590
 591    pub fn map_stream(
 592        mut self,
 593        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
 594    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 595    {
 596        events.flat_map(move |event| {
 597            futures::stream::iter(match event {
 598                Ok(event) => self.map_event(event),
 599                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
 600            })
 601        })
 602    }
 603
 604    pub fn map_event(
 605        &mut self,
 606        event: mistral::StreamResponse,
 607    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 608        let Some(choice) = event.choices.first() else {
 609            return vec![Err(LanguageModelCompletionError::from(anyhow!(
 610                "Response contained no choices"
 611            )))];
 612        };
 613
 614        let mut events = Vec::new();
 615        if let Some(content) = choice.delta.content.as_ref() {
 616            match content {
 617                mistral::MessageContentDelta::Text(text) => {
 618                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 619                }
 620                mistral::MessageContentDelta::Parts(parts) => {
 621                    for part in parts {
 622                        match part {
 623                            mistral::MessagePart::Text { text } => {
 624                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 625                            }
 626                            mistral::MessagePart::Thinking { thinking } => {
 627                                for tp in thinking.iter().cloned() {
 628                                    match tp {
 629                                        mistral::ThinkingPart::Text { text } => {
 630                                            events.push(Ok(
 631                                                LanguageModelCompletionEvent::Thinking {
 632                                                    text,
 633                                                    signature: None,
 634                                                },
 635                                            ));
 636                                        }
 637                                    }
 638                                }
 639                            }
 640                            mistral::MessagePart::ImageUrl { .. } => {
 641                                // We currently don't emit a separate event for images in responses.
 642                            }
 643                        }
 644                    }
 645                }
 646            }
 647        }
 648
 649        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
 650            for tool_call in tool_calls {
 651                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
 652
 653                if let Some(tool_id) = tool_call.id.clone()
 654                    && !tool_id.is_empty()
 655                {
 656                    entry.id = tool_id;
 657                }
 658
 659                if let Some(function) = tool_call.function.as_ref() {
 660                    if let Some(name) = function.name.clone()
 661                        && !name.is_empty()
 662                    {
 663                        entry.name = name;
 664                    }
 665
 666                    if let Some(arguments) = function.arguments.clone() {
 667                        entry.arguments.push_str(&arguments);
 668                    }
 669                }
 670
 671                if !entry.id.is_empty() && !entry.name.is_empty() {
 672                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
 673                        &fix_streamed_json(&entry.arguments),
 674                    ) {
 675                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
 676                            LanguageModelToolUse {
 677                                id: entry.id.clone().into(),
 678                                name: entry.name.as_str().into(),
 679                                is_input_complete: false,
 680                                input,
 681                                raw_input: entry.arguments.clone(),
 682                                thought_signature: None,
 683                            },
 684                        )));
 685                    }
 686                }
 687            }
 688        }
 689
 690        if let Some(usage) = event.usage {
 691            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 692                input_tokens: usage.prompt_tokens,
 693                output_tokens: usage.completion_tokens,
 694                cache_creation_input_tokens: 0,
 695                cache_read_input_tokens: 0,
 696            })));
 697        }
 698
 699        if let Some(finish_reason) = choice.finish_reason.as_deref() {
 700            match finish_reason {
 701                "stop" => {
 702                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 703                }
 704                "tool_calls" => {
 705                    events.extend(self.process_tool_calls());
 706                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 707                }
 708                unexpected => {
 709                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
 710                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 711                }
 712            }
 713        }
 714
 715        events
 716    }
 717
 718    fn process_tool_calls(
 719        &mut self,
 720    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 721        let mut results = Vec::new();
 722
 723        for (_, tool_call) in self.tool_calls_by_index.drain() {
 724            if tool_call.id.is_empty() || tool_call.name.is_empty() {
 725                results.push(Err(LanguageModelCompletionError::from(anyhow!(
 726                    "Received incomplete tool call: missing id or name"
 727                ))));
 728                continue;
 729            }
 730
 731            match parse_tool_arguments(&tool_call.arguments) {
 732                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
 733                    LanguageModelToolUse {
 734                        id: tool_call.id.into(),
 735                        name: tool_call.name.into(),
 736                        is_input_complete: true,
 737                        input,
 738                        raw_input: tool_call.arguments,
 739                        thought_signature: None,
 740                    },
 741                ))),
 742                Err(error) => {
 743                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 744                        id: tool_call.id.into(),
 745                        tool_name: tool_call.name.into(),
 746                        raw_input: tool_call.arguments.into(),
 747                        json_parse_error: error.to_string(),
 748                    }))
 749                }
 750            }
 751        }
 752
 753        results
 754    }
 755}
 756
 757#[derive(Default)]
 758struct RawToolCall {
 759    id: String,
 760    name: String,
 761    arguments: String,
 762}
 763
 764struct ConfigurationView {
 765    api_key_editor: Entity<InputField>,
 766    state: Entity<State>,
 767    load_credentials_task: Option<Task<()>>,
 768}
 769
 770impl ConfigurationView {
 771    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 772        let api_key_editor =
 773            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
 774
 775        cx.observe(&state, |_, _, cx| {
 776            cx.notify();
 777        })
 778        .detach();
 779
 780        let load_credentials_task = Some(cx.spawn_in(window, {
 781            let state = state.clone();
 782            async move |this, cx| {
 783                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
 784                    // We don't log an error, because "not signed in" is also an error.
 785                    let _ = task.await;
 786                }
 787
 788                this.update(cx, |this, cx| {
 789                    this.load_credentials_task = None;
 790                    cx.notify();
 791                })
 792                .log_err();
 793            }
 794        }));
 795
 796        Self {
 797            api_key_editor,
 798            state,
 799            load_credentials_task,
 800        }
 801    }
 802
 803    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 804        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 805        if api_key.is_empty() {
 806            return;
 807        }
 808
 809        // url changes can cause the editor to be displayed again
 810        self.api_key_editor
 811            .update(cx, |editor, cx| editor.set_text("", window, cx));
 812
 813        let state = self.state.clone();
 814        cx.spawn_in(window, async move |_, cx| {
 815            state
 816                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 817                .await
 818        })
 819        .detach_and_log_err(cx);
 820    }
 821
 822    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 823        self.api_key_editor
 824            .update(cx, |editor, cx| editor.set_text("", window, cx));
 825
 826        let state = self.state.clone();
 827        cx.spawn_in(window, async move |_, cx| {
 828            state
 829                .update(cx, |state, cx| state.set_api_key(None, cx))
 830                .await
 831        })
 832        .detach_and_log_err(cx);
 833    }
 834
 835    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
 836        !self.state.read(cx).is_authenticated()
 837    }
 838}
 839
 840impl Render for ConfigurationView {
 841    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 842        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 843        let configured_card_label = if env_var_set {
 844            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
 845        } else {
 846            let api_url = MistralLanguageModelProvider::api_url(cx);
 847            if api_url == MISTRAL_API_URL {
 848                "API key configured".to_string()
 849            } else {
 850                format!("API key configured for {}", api_url)
 851            }
 852        };
 853
 854        if self.load_credentials_task.is_some() {
 855            div().child(Label::new("Loading credentials...")).into_any()
 856        } else if self.should_render_api_key_editor(cx) {
 857            v_flex()
 858                .size_full()
 859                .on_action(cx.listener(Self::save_api_key))
 860                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
 861                .child(
 862                    List::new()
 863                        .child(
 864                            ListBulletItem::new("")
 865                                .child(Label::new("Create one by visiting"))
 866                                .child(ButtonLink::new("Mistral's console", "https://console.mistral.ai/api-keys"))
 867                        )
 868                        .child(
 869                            ListBulletItem::new("Ensure your Mistral account has credits")
 870                        )
 871                        .child(
 872                            ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
 873                        ),
 874                )
 875                .child(self.api_key_editor.clone())
 876                .child(
 877                    Label::new(
 878                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
 879                    )
 880                    .size(LabelSize::Small).color(Color::Muted),
 881                )
 882                .into_any()
 883        } else {
 884            v_flex()
 885                .size_full()
 886                .gap_1()
 887                .child(
 888                    ConfiguredApiCard::new(configured_card_label)
 889                        .disabled(env_var_set)
 890                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 891                        .when(env_var_set, |this| {
 892                            this.tooltip_label(format!(
 893                                "To reset your API key, \
 894                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
 895                            ))
 896                        }),
 897                )
 898                .into_any()
 899        }
 900    }
 901}
 902
 903#[cfg(test)]
 904mod tests {
 905    use super::*;
 906    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
 907
 908    #[test]
 909    fn test_into_mistral_basic_conversion() {
 910        let request = LanguageModelRequest {
 911            messages: vec![
 912                LanguageModelRequestMessage {
 913                    role: Role::System,
 914                    content: vec![MessageContent::Text("System prompt".into())],
 915                    cache: false,
 916                    reasoning_details: None,
 917                },
 918                LanguageModelRequestMessage {
 919                    role: Role::User,
 920                    content: vec![MessageContent::Text("Hello".into())],
 921                    cache: false,
 922                    reasoning_details: None,
 923                },
 924                // should skip empty assistant messages
 925                LanguageModelRequestMessage {
 926                    role: Role::Assistant,
 927                    content: vec![MessageContent::Text("".into())],
 928                    cache: false,
 929                    reasoning_details: None,
 930                },
 931            ],
 932            temperature: Some(0.5),
 933            tools: vec![],
 934            tool_choice: None,
 935            thread_id: Some("abcdef".into()),
 936            prompt_id: None,
 937            intent: None,
 938            stop: vec![],
 939            thinking_allowed: true,
 940            thinking_effort: None,
 941            speed: Default::default(),
 942        };
 943
 944        let (mistral_request, affinity) =
 945            into_mistral(request, mistral::Model::MistralSmallLatest, None);
 946
 947        assert_eq!(mistral_request.model, "mistral-small-latest");
 948        assert_eq!(mistral_request.temperature, Some(0.5));
 949        assert_eq!(mistral_request.messages.len(), 2);
 950        assert!(mistral_request.stream);
 951        assert_eq!(affinity, Some("abcdef".into()));
 952    }
 953
 954    #[test]
 955    fn test_into_mistral_with_image() {
 956        let request = LanguageModelRequest {
 957            messages: vec![LanguageModelRequestMessage {
 958                role: Role::User,
 959                content: vec![
 960                    MessageContent::Text("What's in this image?".into()),
 961                    MessageContent::Image(LanguageModelImage {
 962                        source: "base64data".into(),
 963                        size: None,
 964                    }),
 965                ],
 966                cache: false,
 967                reasoning_details: None,
 968            }],
 969            tools: vec![],
 970            tool_choice: None,
 971            temperature: None,
 972            thread_id: None,
 973            prompt_id: None,
 974            intent: None,
 975            stop: vec![],
 976            thinking_allowed: true,
 977            thinking_effort: None,
 978            speed: None,
 979        };
 980
 981        let (mistral_request, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
 982
 983        assert_eq!(mistral_request.messages.len(), 1);
 984        assert!(matches!(
 985            &mistral_request.messages[0],
 986            mistral::RequestMessage::User {
 987                content: mistral::MessageContent::Multipart { .. }
 988            }
 989        ));
 990
 991        if let mistral::RequestMessage::User {
 992            content: mistral::MessageContent::Multipart { content },
 993        } = &mistral_request.messages[0]
 994        {
 995            assert_eq!(content.len(), 2);
 996            assert!(matches!(
 997                &content[0],
 998                mistral::MessagePart::Text { text } if text == "What's in this image?"
 999            ));
1000            assert!(matches!(
1001                &content[1],
1002                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
1003            ));
1004        }
1005    }
1006}