// mistral.rs — Mistral AI language model provider.

   1use anyhow::{Result, anyhow};
   2use collections::BTreeMap;
   3use credentials_provider::CredentialsProvider;
   4
   5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
   6use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
   7use http_client::HttpClient;
   8use language_model::{
   9    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  10    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  11    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  12    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
  13    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
  14};
  15pub use mistral::{MISTRAL_API_URL, StreamResponse};
  16pub use settings::MistralAvailableModel as AvailableModel;
  17use settings::{Settings, SettingsStore};
  18use std::collections::HashMap;
  19use std::pin::Pin;
  20use std::sync::{Arc, LazyLock};
  21use strum::IntoEnumIterator;
  22use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  23use ui_input::InputField;
  24use util::ResultExt;
  25
  26use language_model::util::{fix_streamed_json, parse_tool_arguments};
  27
/// Stable identifier under which this provider is registered.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Human-readable provider name (also used in error messages such as `NoApiKey`).
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Environment variable consulted for an API key (see `ApiKeyState::new` below).
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
  33
/// User-configurable settings for the Mistral provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    // Base URL for the Mistral API; when empty, `MISTRAL_API_URL` is used
    // (see `MistralLanguageModelProvider::api_url`).
    pub api_url: String,
    // Additional models declared in settings; merged over the built-in model
    // list in `provided_models`, overriding entries with the same id.
    pub available_models: Vec<AvailableModel>,
}
  39
/// Language-model provider backed by the Mistral API.
pub struct MistralLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared authentication state; observed by the configuration UI.
    pub state: Entity<State>,
}
  44
/// Authentication state: the API key (keyed by API URL) and the credentials
/// store used to persist it.
pub struct State {
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
  49
impl State {
    /// True once an API key is available (stored, user-entered, or from the
    /// `MISTRAL_API_KEY` environment variable).
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores the API key for the current API URL (or clears it when `None`)
    /// through the credentials provider. Returns the async store task.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state.store(
            api_url,
            api_key,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }

    /// Loads the API key for the current API URL if not already loaded.
    /// "Not signed in" surfaces as an `AuthenticateError`.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = MistralLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }
}
  78
// Newtype wrapper so the provider can be stored as a gpui global.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
  82
  83impl MistralLanguageModelProvider {
  84    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
  85        cx.try_global::<GlobalMistralLanguageModelProvider>()
  86            .map(|this| &this.0)
  87    }
  88
  89    pub fn global(
  90        http_client: Arc<dyn HttpClient>,
  91        credentials_provider: Arc<dyn CredentialsProvider>,
  92        cx: &mut App,
  93    ) -> Arc<Self> {
  94        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
  95            return this.0.clone();
  96        }
  97        let state = cx.new(|cx| {
  98            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
  99                let credentials_provider = this.credentials_provider.clone();
 100                let api_url = Self::api_url(cx);
 101                this.api_key_state.handle_url_change(
 102                    api_url,
 103                    |this| &mut this.api_key_state,
 104                    credentials_provider,
 105                    cx,
 106                );
 107                cx.notify();
 108            })
 109            .detach();
 110            State {
 111                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
 112                credentials_provider,
 113            }
 114        });
 115
 116        let this = Arc::new(Self { http_client, state });
 117        cx.set_global(GlobalMistralLanguageModelProvider(this));
 118        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
 119    }
 120
 121    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
 122        Arc::new(MistralLanguageModel {
 123            id: LanguageModelId::from(model.id().to_string()),
 124            model,
 125            state: self.state.clone(),
 126            http_client: self.http_client.clone(),
 127            request_limiter: RateLimiter::new(4),
 128        })
 129    }
 130
 131    fn settings(cx: &App) -> &MistralSettings {
 132        &crate::AllLanguageModelSettings::get_global(cx).mistral
 133    }
 134
 135    pub fn api_url(cx: &App) -> SharedString {
 136        let api_url = &Self::settings(cx).api_url;
 137        if api_url.is_empty() {
 138            mistral::MISTRAL_API_URL.into()
 139        } else {
 140            SharedString::new(api_url.as_str())
 141        }
 142    }
 143}
 144
impl LanguageModelProviderState for MistralLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the auth state entity so callers can observe authentication
    /// changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 152
 153impl LanguageModelProvider for MistralLanguageModelProvider {
 154    fn id(&self) -> LanguageModelProviderId {
 155        PROVIDER_ID
 156    }
 157
 158    fn name(&self) -> LanguageModelProviderName {
 159        PROVIDER_NAME
 160    }
 161
 162    fn icon(&self) -> IconOrSvg {
 163        IconOrSvg::Icon(IconName::AiMistral)
 164    }
 165
 166    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 167        Some(self.create_language_model(mistral::Model::default()))
 168    }
 169
 170    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 171        Some(self.create_language_model(mistral::Model::default_fast()))
 172    }
 173
 174    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 175        let mut models = BTreeMap::default();
 176
 177        // Add base models from mistral::Model::iter()
 178        for model in mistral::Model::iter() {
 179            if !matches!(model, mistral::Model::Custom { .. }) {
 180                models.insert(model.id().to_string(), model);
 181            }
 182        }
 183
 184        // Override with available models from settings
 185        for model in &Self::settings(cx).available_models {
 186            models.insert(
 187                model.name.clone(),
 188                mistral::Model::Custom {
 189                    name: model.name.clone(),
 190                    display_name: model.display_name.clone(),
 191                    max_tokens: model.max_tokens,
 192                    max_output_tokens: model.max_output_tokens,
 193                    max_completion_tokens: model.max_completion_tokens,
 194                    supports_tools: model.supports_tools,
 195                    supports_images: model.supports_images,
 196                    supports_thinking: model.supports_thinking,
 197                },
 198            );
 199        }
 200
 201        models
 202            .into_values()
 203            .map(|model| {
 204                Arc::new(MistralLanguageModel {
 205                    id: LanguageModelId::from(model.id().to_string()),
 206                    model,
 207                    state: self.state.clone(),
 208                    http_client: self.http_client.clone(),
 209                    request_limiter: RateLimiter::new(4),
 210                }) as Arc<dyn LanguageModel>
 211            })
 212            .collect()
 213    }
 214
 215    fn is_authenticated(&self, cx: &App) -> bool {
 216        self.state.read(cx).is_authenticated()
 217    }
 218
 219    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 220        self.state.update(cx, |state, cx| state.authenticate(cx))
 221    }
 222
 223    fn configuration_view(
 224        &self,
 225        _target_agent: language_model::ConfigurationViewTargetAgent,
 226        window: &mut Window,
 227        cx: &mut App,
 228    ) -> AnyView {
 229        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 230            .into()
 231    }
 232
 233    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 234        self.state
 235            .update(cx, |state, cx| state.set_api_key(None, cx))
 236    }
 237}
 238
/// A single Mistral model exposed through the `LanguageModel` trait.
pub struct MistralLanguageModel {
    id: LanguageModelId,
    model: mistral::Model,
    // Shared auth state; the API key is read from it at request time.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Limits concurrent requests to the API (configured with 4).
    request_limiter: RateLimiter,
}
 246
impl MistralLanguageModel {
    /// Starts a streaming completion request against the Mistral API.
    ///
    /// Reads the API key and URL from shared state up front, then runs the
    /// HTTP request through the rate limiter. Fails with
    /// `LanguageModelCompletionError::NoApiKey` when no key is configured.
    /// `affinity` is forwarded to `mistral::stream_completion` (it carries the
    /// request's thread id — see `into_mistral`).
    fn stream_completion(
        &self,
        request: mistral::Request,
        affinity: Option<String>,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = MistralLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request = mistral::stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                affinity,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 284
impl LanguageModel for MistralLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    // Tool-call deltas are streamed: `into_mistral` requests
    // `stream_tool_calls: Some(true)` in the stream options.
    fn supports_streaming_tools(&self) -> bool {
        true
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn telemetry_id(&self) -> String {
        format!("mistral/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Converts the request with `into_mistral`, issues it via the inherent
    /// `stream_completion`, and maps raw stream events through
    /// `MistralEventMapper`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let (request, affinity) =
            into_mistral(request, self.model.clone(), self.max_output_tokens());
        let stream = self.stream_completion(request, affinity, cx);

        async move {
            let stream = stream.await?;
            let mapper = MistralEventMapper::new();
            Ok(mapper.map_stream(stream).boxed())
        }
        .boxed()
    }
}
 353
/// Converts a generic `LanguageModelRequest` into a Mistral API request.
///
/// Returns the request plus the request's `thread_id`, which the caller passes
/// as an affinity hint to `mistral::stream_completion`.
///
/// Notable mappings:
/// - Tool results inside a User message are emitted as separate `Tool`
///   messages (pushed before the User message's own content).
/// - Each assistant tool use is appended to the preceding Assistant message's
///   `tool_calls` when possible, otherwise starts a content-less Assistant
///   message.
/// - Thinking and image content are forwarded only when the model supports
///   them; redacted thinking is dropped.
pub fn into_mistral(
    request: LanguageModelRequest,
    model: mistral::Model,
    max_output_tokens: Option<u64>,
) -> (mistral::Request, Option<String>) {
    // Requests are always streamed (see `stream_options` below).
    let stream = true;

    let mut messages = Vec::new();
    for message in &request.messages {
        match message.role {
            Role::User => {
                let mut message_content = mistral::MessageContent::empty();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            message_content
                                .push_part(mistral::MessagePart::Text { text: text.clone() });
                        }
                        MessageContent::Image(image_content) => {
                            if model.supports_images() {
                                message_content.push_part(mistral::MessagePart::ImageUrl {
                                    image_url: image_content.to_base64_url(),
                                });
                            }
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                message_content.push_part(mistral::MessagePart::Thinking {
                                    thinking: vec![mistral::ThinkingPart::Text {
                                        text: text.clone(),
                                    }],
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::ToolUse(_) => {
                            // Tool use is not supported in User messages for Mistral
                        }
                        MessageContent::ToolResult(tool_result) => {
                            let tool_content = match &tool_result.content {
                                LanguageModelToolResultContent::Text(text) => text.to_string(),
                                LanguageModelToolResultContent::Image(_) => {
                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
                                }
                            };
                            messages.push(mistral::RequestMessage::Tool {
                                content: tool_content,
                                tool_call_id: tool_result.tool_use_id.to_string(),
                            });
                        }
                    }
                }
                // Only emit the User message if it accumulated content
                // (presumably `MessageContent::empty()` starts as an empty
                // `Plain` — TODO confirm against the mistral crate).
                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
                {
                    messages.push(mistral::RequestMessage::User {
                        content: message_content,
                    });
                }
            }
            Role::Assistant => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) if text.is_empty() => {
                            // Mistral API returns a 400 if there's neither content nor tool_calls
                        }
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::Assistant {
                                content: Some(mistral::MessageContent::Plain {
                                    content: text.clone(),
                                }),
                                tool_calls: Vec::new(),
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: Some(mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    }),
                                    tool_calls: Vec::new(),
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_) => {}
                        MessageContent::ToolUse(tool_use) => {
                            let tool_call = mistral::ToolCall {
                                id: tool_use.id.to_string(),
                                content: mistral::ToolCallContent::Function {
                                    function: mistral::FunctionContent {
                                        name: tool_use.name.to_string(),
                                        arguments: serde_json::to_string(&tool_use.input)
                                            .unwrap_or_default(),
                                    },
                                },
                            };

                            // Attach to the previous Assistant message when one
                            // exists; otherwise start a content-less one.
                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
                                messages.last_mut()
                            {
                                tool_calls.push(tool_call);
                            } else {
                                messages.push(mistral::RequestMessage::Assistant {
                                    content: None,
                                    tool_calls: vec![tool_call],
                                });
                            }
                        }
                        MessageContent::ToolResult(_) => {
                            // Tool results are not supported in Assistant messages
                        }
                    }
                }
            }
            Role::System => {
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) => {
                            messages.push(mistral::RequestMessage::System {
                                content: mistral::MessageContent::Plain {
                                    content: text.clone(),
                                },
                            });
                        }
                        MessageContent::Thinking { text, .. } => {
                            if model.supports_thinking() {
                                messages.push(mistral::RequestMessage::System {
                                    content: mistral::MessageContent::Multipart {
                                        content: vec![mistral::MessagePart::Thinking {
                                            thinking: vec![mistral::ThinkingPart::Text {
                                                text: text.clone(),
                                            }],
                                        }],
                                    },
                                });
                            }
                        }
                        MessageContent::RedactedThinking(_) => {}
                        MessageContent::Image(_)
                        | MessageContent::ToolUse(_)
                        | MessageContent::ToolResult(_) => {
                            // Images and tools are not supported in System messages
                        }
                    }
                }
            }
        }
    }

    (
        mistral::Request {
            model: model.id().to_string(),
            messages,
            stream,
            // Ask the API to stream tool-call deltas as well.
            stream_options: if stream {
                Some(mistral::StreamOptions {
                    stream_tool_calls: Some(true),
                })
            } else {
                None
            },
            max_tokens: max_output_tokens,
            temperature: request.temperature,
            response_format: None,
            // Default to Auto whenever tools are present; tool choices other
            // than None are meaningless without tools.
            tool_choice: match request.tool_choice {
                Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
                    Some(mistral::ToolChoice::Auto)
                }
                Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
                    Some(mistral::ToolChoice::Any)
                }
                Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
                _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
                _ => None,
            },
            // Parallel tool calls are explicitly disabled when tools are used.
            parallel_tool_calls: if !request.tools.is_empty() {
                Some(false)
            } else {
                None
            },
            tools: request
                .tools
                .into_iter()
                .map(|tool| mistral::ToolDefinition::Function {
                    function: mistral::FunctionDefinition {
                        name: tool.name,
                        description: Some(tool.description),
                        parameters: Some(tool.input_schema),
                    },
                })
                .collect(),
        },
        request.thread_id,
    )
}
 553
/// Maps raw Mistral stream responses to `LanguageModelCompletionEvent`s,
/// accumulating partial tool calls across stream deltas.
pub struct MistralEventMapper {
    // Partial tool calls keyed by their `index` field in the deltas.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
 557
impl MistralEventMapper {
    /// Creates a mapper with no pending tool calls.
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    /// Consumes a raw event stream and yields completion events. Each input
    /// event may expand into several output events (text, tool use, usage,
    /// stop); stream errors are passed through as completion errors.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
            })
        })
    }

    /// Maps a single stream response to zero or more completion events.
    ///
    /// Only the first choice is inspected; a response with no choices is an
    /// error. Tool-call deltas are merged into `tool_calls_by_index`; finished
    /// calls are flushed by `process_tool_calls` on the `"tool_calls"` finish
    /// reason.
    pub fn map_event(
        &mut self,
        event: mistral::StreamResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let Some(choice) = event.choices.first() else {
            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                "Response contained no choices"
            )))];
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content.as_ref() {
            match content {
                mistral::MessageContentDelta::Text(text) => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                }
                mistral::MessageContentDelta::Parts(parts) => {
                    for part in parts {
                        match part {
                            mistral::MessagePart::Text { text } => {
                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
                            }
                            mistral::MessagePart::Thinking { thinking } => {
                                for tp in thinking.iter().cloned() {
                                    match tp {
                                        mistral::ThinkingPart::Text { text } => {
                                            events.push(Ok(
                                                LanguageModelCompletionEvent::Thinking {
                                                    text,
                                                    signature: None,
                                                },
                                            ));
                                        }
                                    }
                                }
                            }
                            mistral::MessagePart::ImageUrl { .. } => {
                                // We currently don't emit a separate event for images in responses.
                            }
                        }
                    }
                }
            }
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                // Deltas may repeat an empty or literal "null" id; keep the
                // first real one.
                if let Some(tool_id) = tool_call.id.clone()
                    && !tool_id.is_empty()
                    && tool_id != "null"
                {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone()
                        && !name.is_empty()
                    {
                        entry.name = name;
                    }

                    // Argument JSON arrives in fragments; accumulate them.
                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }

                // Emit an incremental (is_input_complete: false) ToolUse event
                // whenever the accumulated arguments parse after repair via
                // `fix_streamed_json`; the final event is produced by
                // `process_tool_calls`.
                if !entry.id.is_empty() && !entry.name.is_empty() {
                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
                        &fix_streamed_json(&entry.arguments),
                    ) {
                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: entry.id.clone().into(),
                                name: entry.name.as_str().into(),
                                is_input_complete: false,
                                input,
                                raw_input: entry.arguments.clone(),
                                thought_signature: None,
                            },
                        )));
                    }
                }
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        if let Some(finish_reason) = choice.finish_reason.as_deref() {
            match finish_reason {
                "stop" => {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
                "tool_calls" => {
                    events.extend(self.process_tool_calls());
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                }
                unexpected => {
                    // Treat unknown finish reasons as end-of-turn but log them.
                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }
        }

        events
    }

    /// Drains all accumulated tool calls, emitting a final
    /// (is_input_complete: true) ToolUse event per call, a parse-error event
    /// for unparseable arguments, or an error for calls missing id/name.
    fn process_tool_calls(
        &mut self,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut results = Vec::new();

        for (_, tool_call) in self.tool_calls_by_index.drain() {
            if tool_call.id.is_empty() || tool_call.name.is_empty() {
                results.push(Err(LanguageModelCompletionError::from(anyhow!(
                    "Received incomplete tool call: missing id or name"
                ))));
                continue;
            }

            match parse_tool_arguments(&tool_call.arguments) {
                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
                    LanguageModelToolUse {
                        id: tool_call.id.into(),
                        name: tool_call.name.into(),
                        is_input_complete: true,
                        input,
                        raw_input: tool_call.arguments,
                        thought_signature: None,
                    },
                ))),
                Err(error) => {
                    // Surface the parse failure as an event (Ok) rather than a
                    // stream error, so the turn can continue.
                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                        id: tool_call.id.into(),
                        tool_name: tool_call.name.into(),
                        raw_input: tool_call.arguments.into(),
                        json_parse_error: error.to_string(),
                    }))
                }
            }
        }

        results
    }
}
 731
/// Accumulator for one tool call assembled from streamed deltas.
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    // Concatenated (possibly still partial) JSON argument fragments.
    arguments: String,
}
 738
/// Configuration UI for entering and resetting the Mistral API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Present while initial credential loading is in flight; cleared when the
    // load completes so the view can stop showing a loading state.
    load_credentials_task: Option<Task<()>>,
}
 744
 745impl ConfigurationView {
 746    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 747        let api_key_editor =
 748            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
 749
 750        cx.observe(&state, |_, _, cx| {
 751            cx.notify();
 752        })
 753        .detach();
 754
 755        let load_credentials_task = Some(cx.spawn_in(window, {
 756            let state = state.clone();
 757            async move |this, cx| {
 758                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
 759                    // We don't log an error, because "not signed in" is also an error.
 760                    let _ = task.await;
 761                }
 762
 763                this.update(cx, |this, cx| {
 764                    this.load_credentials_task = None;
 765                    cx.notify();
 766                })
 767                .log_err();
 768            }
 769        }));
 770
 771        Self {
 772            api_key_editor,
 773            state,
 774            load_credentials_task,
 775        }
 776    }
 777
 778    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 779        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 780        if api_key.is_empty() {
 781            return;
 782        }
 783
 784        // url changes can cause the editor to be displayed again
 785        self.api_key_editor
 786            .update(cx, |editor, cx| editor.set_text("", window, cx));
 787
 788        let state = self.state.clone();
 789        cx.spawn_in(window, async move |_, cx| {
 790            state
 791                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 792                .await
 793        })
 794        .detach_and_log_err(cx);
 795    }
 796
 797    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 798        self.api_key_editor
 799            .update(cx, |editor, cx| editor.set_text("", window, cx));
 800
 801        let state = self.state.clone();
 802        cx.spawn_in(window, async move |_, cx| {
 803            state
 804                .update(cx, |state, cx| state.set_api_key(None, cx))
 805                .await
 806        })
 807        .detach_and_log_err(cx);
 808    }
 809
 810    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
 811        !self.state.read(cx).is_authenticated()
 812    }
 813}
 814
 815impl Render for ConfigurationView {
 816    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 817        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 818        let configured_card_label = if env_var_set {
 819            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
 820        } else {
 821            let api_url = MistralLanguageModelProvider::api_url(cx);
 822            if api_url == MISTRAL_API_URL {
 823                "API key configured".to_string()
 824            } else {
 825                format!("API key configured for {}", api_url)
 826            }
 827        };
 828
 829        if self.load_credentials_task.is_some() {
 830            div().child(Label::new("Loading credentials...")).into_any()
 831        } else if self.should_render_api_key_editor(cx) {
 832            v_flex()
 833                .size_full()
 834                .on_action(cx.listener(Self::save_api_key))
 835                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
 836                .child(
 837                    List::new()
 838                        .child(
 839                            ListBulletItem::new("")
 840                                .child(Label::new("Create one by visiting"))
 841                                .child(ButtonLink::new("Mistral's console", "https://console.mistral.ai/api-keys"))
 842                        )
 843                        .child(
 844                            ListBulletItem::new("Ensure your Mistral account has credits")
 845                        )
 846                        .child(
 847                            ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
 848                        ),
 849                )
 850                .child(self.api_key_editor.clone())
 851                .child(
 852                    Label::new(
 853                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
 854                    )
 855                    .size(LabelSize::Small).color(Color::Muted),
 856                )
 857                .into_any()
 858        } else {
 859            v_flex()
 860                .size_full()
 861                .gap_1()
 862                .child(
 863                    ConfiguredApiCard::new(configured_card_label)
 864                        .disabled(env_var_set)
 865                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 866                        .when(env_var_set, |this| {
 867                            this.tooltip_label(format!(
 868                                "To reset your API key, \
 869                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
 870                            ))
 871                        }),
 872                )
 873                .into_any()
 874        }
 875    }
 876}
 877
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// Builds a single streaming chunk containing either one tool-call delta
    /// (when `finish_reason` is `None`) or a bare finish marker with no
    /// tool-call payload (when `finish_reason` is `Some`).
    fn tool_call_chunk(
        id: Option<&str>,
        name: Option<&str>,
        arguments: Option<&str>,
        finish_reason: Option<&str>,
    ) -> mistral::StreamResponse {
        mistral::StreamResponse {
            id: "resp".into(),
            object: "chat.completion.chunk".into(),
            created: 0,
            model: "test".into(),
            choices: vec![mistral::StreamChoice {
                index: 0,
                delta: mistral::StreamDelta {
                    role: None,
                    content: None,
                    tool_calls: if finish_reason.is_some() {
                        None
                    } else {
                        Some(vec![mistral::ToolCallChunk {
                            index: 0,
                            id: id.map(Into::into),
                            function: Some(mistral::FunctionChunk {
                                name: name.map(Into::into),
                                arguments: arguments.map(Into::into),
                            }),
                        }])
                    },
                },
                finish_reason: finish_reason.map(Into::into),
            }],
            usage: None,
        }
    }

    #[test]
    fn test_streaming_tool_call_ignores_null_id() {
        // Mistral's streaming API sometimes sends `"id": "null"` in continuation chunks.
        // The mapper must keep the id from the first chunk and still append the
        // argument fragments from the later chunk.
        let mut mapper = MistralEventMapper::new();

        // First chunk: real id, tool name, and the opening half of the JSON args.
        mapper.map_event(tool_call_chunk(
            Some("real_id_123"),
            Some("read_file"),
            Some("{\"path\":"),
            None,
        ));
        // Continuation chunk with the literal string "null" as id — must not
        // overwrite the real id.
        mapper.map_event(tool_call_chunk(
            Some("null"),
            None,
            Some("\"a.txt\"}"),
            None,
        ));
        // Finish marker flushes the accumulated tool call as events.
        let events = mapper.map_event(tool_call_chunk(None, None, None, Some("tool_calls")));

        let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] else {
            panic!("Expected first event to be ToolUse, got: {:?}", events[0]);
        };

        assert_eq!(tool_use.id.to_string(), "real_id_123");
        assert_eq!(tool_use.name.as_ref(), "read_file");
        assert_eq!(tool_use.input, serde_json::json!({"path": "a.txt"}));
    }

    #[test]
    fn test_into_mistral_basic_conversion() {
        // System + user + an empty assistant message; the empty one must be
        // dropped during conversion.
        let request = LanguageModelRequest {
            messages: vec![
                LanguageModelRequestMessage {
                    role: Role::System,
                    content: vec![MessageContent::Text("System prompt".into())],
                    cache: false,
                    reasoning_details: None,
                },
                LanguageModelRequestMessage {
                    role: Role::User,
                    content: vec![MessageContent::Text("Hello".into())],
                    cache: false,
                    reasoning_details: None,
                },
                // should skip empty assistant messages
                LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::Text("".into())],
                    cache: false,
                    reasoning_details: None,
                },
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: Some("abcdef".into()),
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: Default::default(),
        };

        let (mistral_request, affinity) =
            into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        // Only system + user survive; the empty assistant message is skipped.
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
        // The thread id is carried through as the request affinity key.
        assert_eq!(affinity, Some("abcdef".into()));
    }

    #[test]
    fn test_into_mistral_with_image() {
        // A user message mixing text and an image must become a single
        // multipart message (text part + data-URL image part).
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: None,
                    }),
                ],
                cache: false,
                reasoning_details: None,
            }],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let (mistral_request, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        assert!(matches!(
            &mistral_request.messages[0],
            mistral::RequestMessage::User {
                content: mistral::MessageContent::Multipart { .. }
            }
        ));

        if let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content },
        } = &mistral_request.messages[0]
        {
            assert_eq!(content.len(), 2);
            assert!(matches!(
                &content[0],
                mistral::MessagePart::Text { text } if text == "What's in this image?"
            ));
            // The raw base64 payload is wrapped into a PNG data URL.
            assert!(matches!(
                &content[1],
                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
            ));
        }
    }
}