mistral.rs

   1use anyhow::{Result, anyhow};
   2use collections::BTreeMap;
   3use credentials_provider::CredentialsProvider;
   4
   5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
   6use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
   7use http_client::HttpClient;
   8use language_model::{
   9    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  10    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  11    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  12    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
  13    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
  14};
  15pub use mistral::{MISTRAL_API_URL, StreamResponse};
  16pub use settings::MistralAvailableModel as AvailableModel;
  17use settings::{Settings, SettingsStore};
  18use std::collections::HashMap;
  19use std::pin::Pin;
  20use std::sync::{Arc, LazyLock};
  21use strum::IntoEnumIterator;
  22use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  23use ui_input::InputField;
  24use util::ResultExt;
  25
  26use language_model::util::{fix_streamed_json, parse_tool_arguments};
  27
/// Stable identifier reported by the provider and by every model it creates.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Human-readable provider name; also attached to `NoApiKey` errors.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Name of the environment variable that may supply the Mistral API key.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
/// Lazily-built `EnvVar` handed to `ApiKeyState::new` when the provider state is created.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
  33
/// Mistral provider settings, read from `AllLanguageModelSettings`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// Base URL of the Mistral API. An empty string means `mistral::MISTRAL_API_URL` is used.
    pub api_url: String,
    /// Extra models declared by the user. An entry whose `name` equals a built-in model id
    /// replaces that built-in in the provider's model list.
    pub available_models: Vec<AvailableModel>,
}
  39
/// Language model provider backed by the Mistral chat-completions API.
pub struct MistralLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared authentication state, also handed to each model and to the configuration view.
    pub state: Entity<State>,
}
  44
/// Provider-wide authentication state.
pub struct State {
    /// Tracks the API key for the currently configured API URL; seeded with `API_KEY_ENV_VAR`.
    api_key_state: ApiKeyState,
    /// Backend passed to `api_key_state` whenever a key is loaded or stored.
    credentials_provider: Arc<dyn CredentialsProvider>,
}
  49
  50impl State {
  51    fn is_authenticated(&self) -> bool {
  52        self.api_key_state.has_key()
  53    }
  54
  55    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  56        let credentials_provider = self.credentials_provider.clone();
  57        let api_url = MistralLanguageModelProvider::api_url(cx);
  58        self.api_key_state.store(
  59            api_url,
  60            api_key,
  61            |this| &mut this.api_key_state,
  62            credentials_provider,
  63            cx,
  64        )
  65    }
  66
  67    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  68        let credentials_provider = self.credentials_provider.clone();
  69        let api_url = MistralLanguageModelProvider::api_url(cx);
  70        self.api_key_state.load_if_needed(
  71            api_url,
  72            |this| &mut this.api_key_state,
  73            credentials_provider,
  74            cx,
  75        )
  76    }
  77}
  78
/// App-global slot holding the provider created by `MistralLanguageModelProvider::global`.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
  82
  83impl MistralLanguageModelProvider {
  84    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
  85        cx.try_global::<GlobalMistralLanguageModelProvider>()
  86            .map(|this| &this.0)
  87    }
  88
  89    pub fn global(
  90        http_client: Arc<dyn HttpClient>,
  91        credentials_provider: Arc<dyn CredentialsProvider>,
  92        cx: &mut App,
  93    ) -> Arc<Self> {
  94        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
  95            return this.0.clone();
  96        }
  97        let state = cx.new(|cx| {
  98            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
  99                let credentials_provider = this.credentials_provider.clone();
 100                let api_url = Self::api_url(cx);
 101                this.api_key_state.handle_url_change(
 102                    api_url,
 103                    |this| &mut this.api_key_state,
 104                    credentials_provider,
 105                    cx,
 106                );
 107                cx.notify();
 108            })
 109            .detach();
 110            State {
 111                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
 112                credentials_provider,
 113            }
 114        });
 115
 116        let this = Arc::new(Self { http_client, state });
 117        cx.set_global(GlobalMistralLanguageModelProvider(this));
 118        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
 119    }
 120
 121    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
 122        Arc::new(MistralLanguageModel {
 123            id: LanguageModelId::from(model.id().to_string()),
 124            model,
 125            state: self.state.clone(),
 126            http_client: self.http_client.clone(),
 127            request_limiter: RateLimiter::new(4),
 128        })
 129    }
 130
 131    fn settings(cx: &App) -> &MistralSettings {
 132        &crate::AllLanguageModelSettings::get_global(cx).mistral
 133    }
 134
 135    pub fn api_url(cx: &App) -> SharedString {
 136        let api_url = &Self::settings(cx).api_url;
 137        if api_url.is_empty() {
 138            mistral::MISTRAL_API_URL.into()
 139        } else {
 140            SharedString::new(api_url.as_str())
 141        }
 142    }
 143}
 144
 145impl LanguageModelProviderState for MistralLanguageModelProvider {
 146    type ObservableEntity = State;
 147
 148    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 149        Some(self.state.clone())
 150    }
 151}
 152
 153impl LanguageModelProvider for MistralLanguageModelProvider {
 154    fn id(&self) -> LanguageModelProviderId {
 155        PROVIDER_ID
 156    }
 157
 158    fn name(&self) -> LanguageModelProviderName {
 159        PROVIDER_NAME
 160    }
 161
 162    fn icon(&self) -> IconOrSvg {
 163        IconOrSvg::Icon(IconName::AiMistral)
 164    }
 165
 166    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 167        Some(self.create_language_model(mistral::Model::default()))
 168    }
 169
 170    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 171        Some(self.create_language_model(mistral::Model::default_fast()))
 172    }
 173
 174    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 175        let mut models = BTreeMap::default();
 176
 177        // Add base models from mistral::Model::iter()
 178        for model in mistral::Model::iter() {
 179            if !matches!(model, mistral::Model::Custom { .. }) {
 180                models.insert(model.id().to_string(), model);
 181            }
 182        }
 183
 184        // Override with available models from settings
 185        for model in &Self::settings(cx).available_models {
 186            models.insert(
 187                model.name.clone(),
 188                mistral::Model::Custom {
 189                    name: model.name.clone(),
 190                    display_name: model.display_name.clone(),
 191                    max_tokens: model.max_tokens,
 192                    max_output_tokens: model.max_output_tokens,
 193                    max_completion_tokens: model.max_completion_tokens,
 194                    supports_tools: model.supports_tools,
 195                    supports_images: model.supports_images,
 196                    supports_thinking: model.supports_thinking,
 197                },
 198            );
 199        }
 200
 201        models
 202            .into_values()
 203            .map(|model| {
 204                Arc::new(MistralLanguageModel {
 205                    id: LanguageModelId::from(model.id().to_string()),
 206                    model,
 207                    state: self.state.clone(),
 208                    http_client: self.http_client.clone(),
 209                    request_limiter: RateLimiter::new(4),
 210                }) as Arc<dyn LanguageModel>
 211            })
 212            .collect()
 213    }
 214
 215    fn is_authenticated(&self, cx: &App) -> bool {
 216        self.state.read(cx).is_authenticated()
 217    }
 218
 219    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 220        self.state.update(cx, |state, cx| state.authenticate(cx))
 221    }
 222
 223    fn configuration_view(
 224        &self,
 225        _target_agent: language_model::ConfigurationViewTargetAgent,
 226        window: &mut Window,
 227        cx: &mut App,
 228    ) -> AnyView {
 229        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 230            .into()
 231    }
 232
 233    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 234        self.state
 235            .update(cx, |state, cx| state.set_api_key(None, cx))
 236    }
 237}
 238
/// A single Mistral model exposed through the [`LanguageModel`] trait.
pub struct MistralLanguageModel {
    /// Id derived from `model.id()` when the instance is created.
    id: LanguageModelId,
    model: mistral::Model,
    /// Shared provider state, read for the API key and URL on every request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Limiter (built with `RateLimiter::new(4)`) that wraps every streaming request.
    request_limiter: RateLimiter,
}
 246
 247impl MistralLanguageModel {
 248    fn stream_completion(
 249        &self,
 250        request: mistral::Request,
 251        affinity: Option<String>,
 252        cx: &AsyncApp,
 253    ) -> BoxFuture<
 254        'static,
 255        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
 256    > {
 257        let http_client = self.http_client.clone();
 258
 259        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
 260            let api_url = MistralLanguageModelProvider::api_url(cx);
 261            (state.api_key_state.key(&api_url), api_url)
 262        });
 263
 264        let future = self.request_limiter.stream(async move {
 265            let Some(api_key) = api_key else {
 266                return Err(LanguageModelCompletionError::NoApiKey {
 267                    provider: PROVIDER_NAME,
 268                });
 269            };
 270            let request = mistral::stream_completion(
 271                http_client.as_ref(),
 272                &api_url,
 273                &api_key,
 274                request,
 275                affinity,
 276            );
 277            let response = request.await?;
 278            Ok(response)
 279        });
 280
 281        async move { Ok(future.await?.boxed()) }.boxed()
 282    }
 283}
 284
 285impl LanguageModel for MistralLanguageModel {
 286    fn id(&self) -> LanguageModelId {
 287        self.id.clone()
 288    }
 289
 290    fn name(&self) -> LanguageModelName {
 291        LanguageModelName::from(self.model.display_name().to_string())
 292    }
 293
 294    fn provider_id(&self) -> LanguageModelProviderId {
 295        PROVIDER_ID
 296    }
 297
 298    fn provider_name(&self) -> LanguageModelProviderName {
 299        PROVIDER_NAME
 300    }
 301
 302    fn supports_tools(&self) -> bool {
 303        self.model.supports_tools()
 304    }
 305
 306    fn supports_streaming_tools(&self) -> bool {
 307        true
 308    }
 309
 310    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
 311        self.model.supports_tools()
 312    }
 313
 314    fn supports_images(&self) -> bool {
 315        self.model.supports_images()
 316    }
 317
 318    fn telemetry_id(&self) -> String {
 319        format!("mistral/{}", self.model.id())
 320    }
 321
 322    fn max_token_count(&self) -> u64 {
 323        self.model.max_token_count()
 324    }
 325
 326    fn max_output_tokens(&self) -> Option<u64> {
 327        self.model.max_output_tokens()
 328    }
 329
 330    fn count_tokens(
 331        &self,
 332        request: LanguageModelRequest,
 333        cx: &App,
 334    ) -> BoxFuture<'static, Result<u64>> {
 335        cx.background_spawn(async move {
 336            let messages = request
 337                .messages
 338                .into_iter()
 339                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 340                    role: match message.role {
 341                        Role::User => "user".into(),
 342                        Role::Assistant => "assistant".into(),
 343                        Role::System => "system".into(),
 344                    },
 345                    content: Some(message.string_contents()),
 346                    name: None,
 347                    function_call: None,
 348                })
 349                .collect::<Vec<_>>();
 350
 351            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 352        })
 353        .boxed()
 354    }
 355
 356    fn stream_completion(
 357        &self,
 358        request: LanguageModelRequest,
 359        cx: &AsyncApp,
 360    ) -> BoxFuture<
 361        'static,
 362        Result<
 363            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 364            LanguageModelCompletionError,
 365        >,
 366    > {
 367        let (request, affinity) =
 368            into_mistral(request, self.model.clone(), self.max_output_tokens());
 369        let stream = self.stream_completion(request, affinity, cx);
 370
 371        async move {
 372            let stream = stream.await?;
 373            let mapper = MistralEventMapper::new();
 374            Ok(mapper.map_stream(stream).boxed())
 375        }
 376        .boxed()
 377    }
 378}
 379
 380pub fn into_mistral(
 381    request: LanguageModelRequest,
 382    model: mistral::Model,
 383    max_output_tokens: Option<u64>,
 384) -> (mistral::Request, Option<String>) {
 385    let stream = true;
 386
 387    let mut messages = Vec::new();
 388    for message in &request.messages {
 389        match message.role {
 390            Role::User => {
 391                let mut message_content = mistral::MessageContent::empty();
 392                for content in &message.content {
 393                    match content {
 394                        MessageContent::Text(text) => {
 395                            message_content
 396                                .push_part(mistral::MessagePart::Text { text: text.clone() });
 397                        }
 398                        MessageContent::Image(image_content) => {
 399                            if model.supports_images() {
 400                                message_content.push_part(mistral::MessagePart::ImageUrl {
 401                                    image_url: image_content.to_base64_url(),
 402                                });
 403                            }
 404                        }
 405                        MessageContent::Thinking { text, .. } => {
 406                            if model.supports_thinking() {
 407                                message_content.push_part(mistral::MessagePart::Thinking {
 408                                    thinking: vec![mistral::ThinkingPart::Text {
 409                                        text: text.clone(),
 410                                    }],
 411                                });
 412                            }
 413                        }
 414                        MessageContent::RedactedThinking(_) => {}
 415                        MessageContent::ToolUse(_) => {
 416                            // Tool use is not supported in User messages for Mistral
 417                        }
 418                        MessageContent::ToolResult(tool_result) => {
 419                            let tool_content = match &tool_result.content {
 420                                LanguageModelToolResultContent::Text(text) => text.to_string(),
 421                                LanguageModelToolResultContent::Image(_) => {
 422                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
 423                                }
 424                            };
 425                            messages.push(mistral::RequestMessage::Tool {
 426                                content: tool_content,
 427                                tool_call_id: tool_result.tool_use_id.to_string(),
 428                            });
 429                        }
 430                    }
 431                }
 432                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
 433                {
 434                    messages.push(mistral::RequestMessage::User {
 435                        content: message_content,
 436                    });
 437                }
 438            }
 439            Role::Assistant => {
 440                for content in &message.content {
 441                    match content {
 442                        MessageContent::Text(text) if text.is_empty() => {
 443                            // Mistral API returns a 400 if there's neither content nor tool_calls
 444                        }
 445                        MessageContent::Text(text) => {
 446                            messages.push(mistral::RequestMessage::Assistant {
 447                                content: Some(mistral::MessageContent::Plain {
 448                                    content: text.clone(),
 449                                }),
 450                                tool_calls: Vec::new(),
 451                            });
 452                        }
 453                        MessageContent::Thinking { text, .. } => {
 454                            if model.supports_thinking() {
 455                                messages.push(mistral::RequestMessage::Assistant {
 456                                    content: Some(mistral::MessageContent::Multipart {
 457                                        content: vec![mistral::MessagePart::Thinking {
 458                                            thinking: vec![mistral::ThinkingPart::Text {
 459                                                text: text.clone(),
 460                                            }],
 461                                        }],
 462                                    }),
 463                                    tool_calls: Vec::new(),
 464                                });
 465                            }
 466                        }
 467                        MessageContent::RedactedThinking(_) => {}
 468                        MessageContent::Image(_) => {}
 469                        MessageContent::ToolUse(tool_use) => {
 470                            let tool_call = mistral::ToolCall {
 471                                id: tool_use.id.to_string(),
 472                                content: mistral::ToolCallContent::Function {
 473                                    function: mistral::FunctionContent {
 474                                        name: tool_use.name.to_string(),
 475                                        arguments: serde_json::to_string(&tool_use.input)
 476                                            .unwrap_or_default(),
 477                                    },
 478                                },
 479                            };
 480
 481                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
 482                                messages.last_mut()
 483                            {
 484                                tool_calls.push(tool_call);
 485                            } else {
 486                                messages.push(mistral::RequestMessage::Assistant {
 487                                    content: None,
 488                                    tool_calls: vec![tool_call],
 489                                });
 490                            }
 491                        }
 492                        MessageContent::ToolResult(_) => {
 493                            // Tool results are not supported in Assistant messages
 494                        }
 495                    }
 496                }
 497            }
 498            Role::System => {
 499                for content in &message.content {
 500                    match content {
 501                        MessageContent::Text(text) => {
 502                            messages.push(mistral::RequestMessage::System {
 503                                content: mistral::MessageContent::Plain {
 504                                    content: text.clone(),
 505                                },
 506                            });
 507                        }
 508                        MessageContent::Thinking { text, .. } => {
 509                            if model.supports_thinking() {
 510                                messages.push(mistral::RequestMessage::System {
 511                                    content: mistral::MessageContent::Multipart {
 512                                        content: vec![mistral::MessagePart::Thinking {
 513                                            thinking: vec![mistral::ThinkingPart::Text {
 514                                                text: text.clone(),
 515                                            }],
 516                                        }],
 517                                    },
 518                                });
 519                            }
 520                        }
 521                        MessageContent::RedactedThinking(_) => {}
 522                        MessageContent::Image(_)
 523                        | MessageContent::ToolUse(_)
 524                        | MessageContent::ToolResult(_) => {
 525                            // Images and tools are not supported in System messages
 526                        }
 527                    }
 528                }
 529            }
 530        }
 531    }
 532
 533    (
 534        mistral::Request {
 535            model: model.id().to_string(),
 536            messages,
 537            stream,
 538            stream_options: if stream {
 539                Some(mistral::StreamOptions {
 540                    stream_tool_calls: Some(true),
 541                })
 542            } else {
 543                None
 544            },
 545            max_tokens: max_output_tokens,
 546            temperature: request.temperature,
 547            response_format: None,
 548            tool_choice: match request.tool_choice {
 549                Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
 550                    Some(mistral::ToolChoice::Auto)
 551                }
 552                Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
 553                    Some(mistral::ToolChoice::Any)
 554                }
 555                Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
 556                _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
 557                _ => None,
 558            },
 559            parallel_tool_calls: if !request.tools.is_empty() {
 560                Some(false)
 561            } else {
 562                None
 563            },
 564            tools: request
 565                .tools
 566                .into_iter()
 567                .map(|tool| mistral::ToolDefinition::Function {
 568                    function: mistral::FunctionDefinition {
 569                        name: tool.name,
 570                        description: Some(tool.description),
 571                        parameters: Some(tool.input_schema),
 572                    },
 573                })
 574                .collect(),
 575        },
 576        request.thread_id,
 577    )
 578}
 579
 580pub struct MistralEventMapper {
 581    tool_calls_by_index: HashMap<usize, RawToolCall>,
 582}
 583
 584impl MistralEventMapper {
 585    pub fn new() -> Self {
 586        Self {
 587            tool_calls_by_index: HashMap::default(),
 588        }
 589    }
 590
 591    pub fn map_stream(
 592        mut self,
 593        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
 594    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 595    {
 596        events.flat_map(move |event| {
 597            futures::stream::iter(match event {
 598                Ok(event) => self.map_event(event),
 599                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
 600            })
 601        })
 602    }
 603
 604    pub fn map_event(
 605        &mut self,
 606        event: mistral::StreamResponse,
 607    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 608        let Some(choice) = event.choices.first() else {
 609            return vec![Err(LanguageModelCompletionError::from(anyhow!(
 610                "Response contained no choices"
 611            )))];
 612        };
 613
 614        let mut events = Vec::new();
 615        if let Some(content) = choice.delta.content.as_ref() {
 616            match content {
 617                mistral::MessageContentDelta::Text(text) => {
 618                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 619                }
 620                mistral::MessageContentDelta::Parts(parts) => {
 621                    for part in parts {
 622                        match part {
 623                            mistral::MessagePart::Text { text } => {
 624                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 625                            }
 626                            mistral::MessagePart::Thinking { thinking } => {
 627                                for tp in thinking.iter().cloned() {
 628                                    match tp {
 629                                        mistral::ThinkingPart::Text { text } => {
 630                                            events.push(Ok(
 631                                                LanguageModelCompletionEvent::Thinking {
 632                                                    text,
 633                                                    signature: None,
 634                                                },
 635                                            ));
 636                                        }
 637                                    }
 638                                }
 639                            }
 640                            mistral::MessagePart::ImageUrl { .. } => {
 641                                // We currently don't emit a separate event for images in responses.
 642                            }
 643                        }
 644                    }
 645                }
 646            }
 647        }
 648
 649        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
 650            for tool_call in tool_calls {
 651                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
 652
 653                if let Some(tool_id) = tool_call.id.clone()
 654                    && !tool_id.is_empty()
 655                    && tool_id != "null"
 656                {
 657                    entry.id = tool_id;
 658                }
 659
 660                if let Some(function) = tool_call.function.as_ref() {
 661                    if let Some(name) = function.name.clone()
 662                        && !name.is_empty()
 663                    {
 664                        entry.name = name;
 665                    }
 666
 667                    if let Some(arguments) = function.arguments.clone() {
 668                        entry.arguments.push_str(&arguments);
 669                    }
 670                }
 671
 672                if !entry.id.is_empty() && !entry.name.is_empty() {
 673                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
 674                        &fix_streamed_json(&entry.arguments),
 675                    ) {
 676                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
 677                            LanguageModelToolUse {
 678                                id: entry.id.clone().into(),
 679                                name: entry.name.as_str().into(),
 680                                is_input_complete: false,
 681                                input,
 682                                raw_input: entry.arguments.clone(),
 683                                thought_signature: None,
 684                            },
 685                        )));
 686                    }
 687                }
 688            }
 689        }
 690
 691        if let Some(usage) = event.usage {
 692            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 693                input_tokens: usage.prompt_tokens,
 694                output_tokens: usage.completion_tokens,
 695                cache_creation_input_tokens: 0,
 696                cache_read_input_tokens: 0,
 697            })));
 698        }
 699
 700        if let Some(finish_reason) = choice.finish_reason.as_deref() {
 701            match finish_reason {
 702                "stop" => {
 703                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 704                }
 705                "tool_calls" => {
 706                    events.extend(self.process_tool_calls());
 707                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 708                }
 709                unexpected => {
 710                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
 711                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 712                }
 713            }
 714        }
 715
 716        events
 717    }
 718
 719    fn process_tool_calls(
 720        &mut self,
 721    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 722        let mut results = Vec::new();
 723
 724        for (_, tool_call) in self.tool_calls_by_index.drain() {
 725            if tool_call.id.is_empty() || tool_call.name.is_empty() {
 726                results.push(Err(LanguageModelCompletionError::from(anyhow!(
 727                    "Received incomplete tool call: missing id or name"
 728                ))));
 729                continue;
 730            }
 731
 732            match parse_tool_arguments(&tool_call.arguments) {
 733                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
 734                    LanguageModelToolUse {
 735                        id: tool_call.id.into(),
 736                        name: tool_call.name.into(),
 737                        is_input_complete: true,
 738                        input,
 739                        raw_input: tool_call.arguments,
 740                        thought_signature: None,
 741                    },
 742                ))),
 743                Err(error) => {
 744                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 745                        id: tool_call.id.into(),
 746                        tool_name: tool_call.name.into(),
 747                        raw_input: tool_call.arguments.into(),
 748                        json_parse_error: error.to_string(),
 749                    }))
 750                }
 751            }
 752        }
 753
 754        results
 755    }
 756}
 757
/// A tool call being assembled from streamed fragments; each field stays empty until a
/// fragment supplies it.
#[derive(Default)]
struct RawToolCall {
    /// Tool call id reported by Mistral.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Concatenation of all argument fragments received so far (possibly incomplete JSON).
    arguments: String,
}
 764
/// Settings UI for entering, saving, and resetting the Mistral API key.
struct ConfigurationView {
    /// Input field where the user types a new API key.
    api_key_editor: Entity<InputField>,
    /// Shared provider state that holds the API key.
    state: Entity<State>,
    /// Credential load started in `new`; reset to `None` once it completes.
    load_credentials_task: Option<Task<()>>,
}
 770
 771impl ConfigurationView {
 772    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 773        let api_key_editor =
 774            cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
 775
 776        cx.observe(&state, |_, _, cx| {
 777            cx.notify();
 778        })
 779        .detach();
 780
 781        let load_credentials_task = Some(cx.spawn_in(window, {
 782            let state = state.clone();
 783            async move |this, cx| {
 784                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
 785                    // We don't log an error, because "not signed in" is also an error.
 786                    let _ = task.await;
 787                }
 788
 789                this.update(cx, |this, cx| {
 790                    this.load_credentials_task = None;
 791                    cx.notify();
 792                })
 793                .log_err();
 794            }
 795        }));
 796
 797        Self {
 798            api_key_editor,
 799            state,
 800            load_credentials_task,
 801        }
 802    }
 803
 804    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 805        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 806        if api_key.is_empty() {
 807            return;
 808        }
 809
 810        // url changes can cause the editor to be displayed again
 811        self.api_key_editor
 812            .update(cx, |editor, cx| editor.set_text("", window, cx));
 813
 814        let state = self.state.clone();
 815        cx.spawn_in(window, async move |_, cx| {
 816            state
 817                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 818                .await
 819        })
 820        .detach_and_log_err(cx);
 821    }
 822
 823    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 824        self.api_key_editor
 825            .update(cx, |editor, cx| editor.set_text("", window, cx));
 826
 827        let state = self.state.clone();
 828        cx.spawn_in(window, async move |_, cx| {
 829            state
 830                .update(cx, |state, cx| state.set_api_key(None, cx))
 831                .await
 832        })
 833        .detach_and_log_err(cx);
 834    }
 835
 836    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
 837        !self.state.read(cx).is_authenticated()
 838    }
 839}
 840
 841impl Render for ConfigurationView {
 842    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 843        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 844        let configured_card_label = if env_var_set {
 845            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
 846        } else {
 847            let api_url = MistralLanguageModelProvider::api_url(cx);
 848            if api_url == MISTRAL_API_URL {
 849                "API key configured".to_string()
 850            } else {
 851                format!("API key configured for {}", api_url)
 852            }
 853        };
 854
 855        if self.load_credentials_task.is_some() {
 856            div().child(Label::new("Loading credentials...")).into_any()
 857        } else if self.should_render_api_key_editor(cx) {
 858            v_flex()
 859                .size_full()
 860                .on_action(cx.listener(Self::save_api_key))
 861                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
 862                .child(
 863                    List::new()
 864                        .child(
 865                            ListBulletItem::new("")
 866                                .child(Label::new("Create one by visiting"))
 867                                .child(ButtonLink::new("Mistral's console", "https://console.mistral.ai/api-keys"))
 868                        )
 869                        .child(
 870                            ListBulletItem::new("Ensure your Mistral account has credits")
 871                        )
 872                        .child(
 873                            ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
 874                        ),
 875                )
 876                .child(self.api_key_editor.clone())
 877                .child(
 878                    Label::new(
 879                        format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
 880                    )
 881                    .size(LabelSize::Small).color(Color::Muted),
 882                )
 883                .into_any()
 884        } else {
 885            v_flex()
 886                .size_full()
 887                .gap_1()
 888                .child(
 889                    ConfiguredApiCard::new(configured_card_label)
 890                        .disabled(env_var_set)
 891                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 892                        .when(env_var_set, |this| {
 893                            this.tooltip_label(format!(
 894                                "To reset your API key, \
 895                                unset the {API_KEY_ENV_VAR_NAME} environment variable."
 896                            ))
 897                        }),
 898                )
 899                .into_any()
 900        }
 901    }
 902}
 903
 904#[cfg(test)]
 905mod tests {
 906    use super::*;
 907    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
 908
 909    fn tool_call_chunk(
 910        id: Option<&str>,
 911        name: Option<&str>,
 912        arguments: Option<&str>,
 913        finish_reason: Option<&str>,
 914    ) -> mistral::StreamResponse {
 915        mistral::StreamResponse {
 916            id: "resp".into(),
 917            object: "chat.completion.chunk".into(),
 918            created: 0,
 919            model: "test".into(),
 920            choices: vec![mistral::StreamChoice {
 921                index: 0,
 922                delta: mistral::StreamDelta {
 923                    role: None,
 924                    content: None,
 925                    tool_calls: if finish_reason.is_some() {
 926                        None
 927                    } else {
 928                        Some(vec![mistral::ToolCallChunk {
 929                            index: 0,
 930                            id: id.map(Into::into),
 931                            function: Some(mistral::FunctionChunk {
 932                                name: name.map(Into::into),
 933                                arguments: arguments.map(Into::into),
 934                            }),
 935                        }])
 936                    },
 937                },
 938                finish_reason: finish_reason.map(Into::into),
 939            }],
 940            usage: None,
 941        }
 942    }
 943
 944    #[test]
 945    fn test_streaming_tool_call_ignores_null_id() {
 946        // Mistral's streaming API sometimes sends `"id": "null"` in continuation chunks.
 947        let mut mapper = MistralEventMapper::new();
 948
 949        mapper.map_event(tool_call_chunk(
 950            Some("real_id_123"),
 951            Some("read_file"),
 952            Some("{\"path\":"),
 953            None,
 954        ));
 955        mapper.map_event(tool_call_chunk(
 956            Some("null"),
 957            None,
 958            Some("\"a.txt\"}"),
 959            None,
 960        ));
 961        let events = mapper.map_event(tool_call_chunk(None, None, None, Some("tool_calls")));
 962
 963        let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] else {
 964            panic!("Expected first event to be ToolUse, got: {:?}", events[0]);
 965        };
 966
 967        assert_eq!(tool_use.id.to_string(), "real_id_123");
 968        assert_eq!(tool_use.name.as_ref(), "read_file");
 969        assert_eq!(tool_use.input, serde_json::json!({"path": "a.txt"}));
 970    }
 971
 972    #[test]
 973    fn test_into_mistral_basic_conversion() {
 974        let request = LanguageModelRequest {
 975            messages: vec![
 976                LanguageModelRequestMessage {
 977                    role: Role::System,
 978                    content: vec![MessageContent::Text("System prompt".into())],
 979                    cache: false,
 980                    reasoning_details: None,
 981                },
 982                LanguageModelRequestMessage {
 983                    role: Role::User,
 984                    content: vec![MessageContent::Text("Hello".into())],
 985                    cache: false,
 986                    reasoning_details: None,
 987                },
 988                // should skip empty assistant messages
 989                LanguageModelRequestMessage {
 990                    role: Role::Assistant,
 991                    content: vec![MessageContent::Text("".into())],
 992                    cache: false,
 993                    reasoning_details: None,
 994                },
 995            ],
 996            temperature: Some(0.5),
 997            tools: vec![],
 998            tool_choice: None,
 999            thread_id: Some("abcdef".into()),
1000            prompt_id: None,
1001            intent: None,
1002            stop: vec![],
1003            thinking_allowed: true,
1004            thinking_effort: None,
1005            speed: Default::default(),
1006        };
1007
1008        let (mistral_request, affinity) =
1009            into_mistral(request, mistral::Model::MistralSmallLatest, None);
1010
1011        assert_eq!(mistral_request.model, "mistral-small-latest");
1012        assert_eq!(mistral_request.temperature, Some(0.5));
1013        assert_eq!(mistral_request.messages.len(), 2);
1014        assert!(mistral_request.stream);
1015        assert_eq!(affinity, Some("abcdef".into()));
1016    }
1017
1018    #[test]
1019    fn test_into_mistral_with_image() {
1020        let request = LanguageModelRequest {
1021            messages: vec![LanguageModelRequestMessage {
1022                role: Role::User,
1023                content: vec![
1024                    MessageContent::Text("What's in this image?".into()),
1025                    MessageContent::Image(LanguageModelImage {
1026                        source: "base64data".into(),
1027                        size: None,
1028                    }),
1029                ],
1030                cache: false,
1031                reasoning_details: None,
1032            }],
1033            tools: vec![],
1034            tool_choice: None,
1035            temperature: None,
1036            thread_id: None,
1037            prompt_id: None,
1038            intent: None,
1039            stop: vec![],
1040            thinking_allowed: true,
1041            thinking_effort: None,
1042            speed: None,
1043        };
1044
1045        let (mistral_request, _) = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
1046
1047        assert_eq!(mistral_request.messages.len(), 1);
1048        assert!(matches!(
1049            &mistral_request.messages[0],
1050            mistral::RequestMessage::User {
1051                content: mistral::MessageContent::Multipart { .. }
1052            }
1053        ));
1054
1055        if let mistral::RequestMessage::User {
1056            content: mistral::MessageContent::Multipart { content },
1057        } = &mistral_request.messages[0]
1058        {
1059            assert_eq!(content.len(), 2);
1060            assert!(matches!(
1061                &content[0],
1062                mistral::MessagePart::Text { text } if text == "What's in this image?"
1063            ));
1064            assert!(matches!(
1065                &content[1],
1066                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
1067            ));
1068        }
1069    }
1070}