1use anyhow::{Result, anyhow};
   2use collections::BTreeMap;
   3use fs::Fs;
   4use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
   5use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
   6use http_client::HttpClient;
   7use language_model::{
   8    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
   9    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  10    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  11    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  12    RateLimiter, Role, StopReason, TokenUsage,
  13};
  14use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
  15pub use settings::MistralAvailableModel as AvailableModel;
  16use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
  17use std::collections::HashMap;
  18use std::pin::Pin;
  19use std::str::FromStr;
  20use std::sync::{Arc, LazyLock};
  21use strum::IntoEnumIterator;
  22use ui::{Icon, IconName, List, Tooltip, prelude::*};
  23use ui_input::SingleLineInput;
  24use util::{ResultExt, truncate_and_trailoff};
  25use zed_env_vars::{EnvVar, env_var};
  26
  27use crate::{api_key::ApiKeyState, ui::InstructionListItem};
  28
/// Identifier returned by `id()` / `provider_id()` for this provider and its models.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Name returned by `name()` / `provider_name()` and reported in `NoApiKey` errors.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");

/// Environment variable passed to the Mistral key state when loading or re-resolving the key.
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);

/// Environment variable passed to the Codestral key state when loading the key.
const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
  37
/// Mistral section of `AllLanguageModelSettings` (read by `MistralLanguageModelProvider::settings`).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// Base URL for chat requests; an empty string falls back to `mistral::MISTRAL_API_URL`.
    pub api_url: String,
    /// Extra models from settings, exposed as `mistral::Model::Custom` entries that override
    /// built-in models with the same id.
    pub available_models: Vec<AvailableModel>,
}
  43
/// Language-model provider for Mistral; it also holds the Codestral API key state.
pub struct MistralLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared authentication state; exposed via `observable_entity`.
    state: Entity<State>,
}
  48
/// Authentication state shared by the provider, its models, and the configuration view.
pub struct State {
    /// Key for the Mistral chat API, tracked against the configured Mistral API URL.
    api_key_state: ApiKeyState,
    /// Key for Codestral, always tracked against `CODESTRAL_API_URL`.
    codestral_api_key_state: ApiKeyState,
}
  53
  54impl State {
  55    const fn is_authenticated(&self) -> bool {
  56        self.api_key_state.has_key()
  57    }
  58
  59    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  60        let api_url = MistralLanguageModelProvider::api_url(cx);
  61        self.api_key_state
  62            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
  63    }
  64
  65    fn set_codestral_api_key(
  66        &mut self,
  67        api_key: Option<String>,
  68        cx: &mut Context<Self>,
  69    ) -> Task<Result<()>> {
  70        self.codestral_api_key_state.store(
  71            CODESTRAL_API_URL.into(),
  72            api_key,
  73            |this| &mut this.codestral_api_key_state,
  74            cx,
  75        )
  76    }
  77
  78    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  79        let api_url = MistralLanguageModelProvider::api_url(cx);
  80        self.api_key_state.load_if_needed(
  81            api_url,
  82            &API_KEY_ENV_VAR,
  83            |this| &mut this.api_key_state,
  84            cx,
  85        )
  86    }
  87
  88    fn authenticate_codestral(
  89        &mut self,
  90        cx: &mut Context<Self>,
  91    ) -> Task<Result<(), AuthenticateError>> {
  92        self.codestral_api_key_state.load_if_needed(
  93            CODESTRAL_API_URL.into(),
  94            &CODESTRAL_API_KEY_ENV_VAR,
  95            |this| &mut this.codestral_api_key_state,
  96            cx,
  97        )
  98    }
  99}
 100
/// Global slot holding the singleton created by `MistralLanguageModelProvider::global`.
struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);

impl Global for GlobalMistralLanguageModelProvider {}
 104
 105impl MistralLanguageModelProvider {
 106    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
 107        cx.try_global::<GlobalMistralLanguageModelProvider>()
 108            .map(|this| &this.0)
 109    }
 110
 111    pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
 112        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
 113            return this.0.clone();
 114        }
 115        let state = cx.new(|cx| {
 116            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 117                let api_url = Self::api_url(cx);
 118                this.api_key_state.handle_url_change(
 119                    api_url,
 120                    &API_KEY_ENV_VAR,
 121                    |this| &mut this.api_key_state,
 122                    cx,
 123                );
 124                cx.notify();
 125            })
 126            .detach();
 127            State {
 128                api_key_state: ApiKeyState::new(Self::api_url(cx)),
 129                codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
 130            }
 131        });
 132
 133        let this = Arc::new(Self { http_client, state });
 134        cx.set_global(GlobalMistralLanguageModelProvider(this));
 135        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
 136    }
 137
 138    pub fn load_codestral_api_key(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 139        self.state
 140            .update(cx, |state, cx| state.authenticate_codestral(cx))
 141    }
 142
 143    pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
 144        self.state.read(cx).codestral_api_key_state.key(url)
 145    }
 146
 147    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
 148        Arc::new(MistralLanguageModel {
 149            id: LanguageModelId::from(model.id().to_string()),
 150            model,
 151            state: self.state.clone(),
 152            http_client: self.http_client.clone(),
 153            request_limiter: RateLimiter::new(4),
 154        })
 155    }
 156
 157    fn settings(cx: &App) -> &MistralSettings {
 158        &crate::AllLanguageModelSettings::get_global(cx).mistral
 159    }
 160
 161    fn api_url(cx: &App) -> SharedString {
 162        let api_url = &Self::settings(cx).api_url;
 163        if api_url.is_empty() {
 164            mistral::MISTRAL_API_URL.into()
 165        } else {
 166            SharedString::new(api_url.as_str())
 167        }
 168    }
 169}
 170
 171impl LanguageModelProviderState for MistralLanguageModelProvider {
 172    type ObservableEntity = State;
 173
 174    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 175        Some(self.state.clone())
 176    }
 177}
 178
 179impl LanguageModelProvider for MistralLanguageModelProvider {
 180    fn id(&self) -> LanguageModelProviderId {
 181        PROVIDER_ID
 182    }
 183
 184    fn name(&self) -> LanguageModelProviderName {
 185        PROVIDER_NAME
 186    }
 187
 188    fn icon(&self) -> IconName {
 189        IconName::AiMistral
 190    }
 191
 192    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 193        Some(self.create_language_model(mistral::Model::default()))
 194    }
 195
 196    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 197        Some(self.create_language_model(mistral::Model::default_fast()))
 198    }
 199
 200    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 201        let mut models = BTreeMap::default();
 202
 203        // Add base models from mistral::Model::iter()
 204        for model in mistral::Model::iter() {
 205            if !matches!(model, mistral::Model::Custom { .. }) {
 206                models.insert(model.id().to_string(), model);
 207            }
 208        }
 209
 210        // Override with available models from settings
 211        for model in &Self::settings(cx).available_models {
 212            models.insert(
 213                model.name.clone(),
 214                mistral::Model::Custom {
 215                    name: model.name.clone(),
 216                    display_name: model.display_name.clone(),
 217                    max_tokens: model.max_tokens,
 218                    max_output_tokens: model.max_output_tokens,
 219                    max_completion_tokens: model.max_completion_tokens,
 220                    supports_tools: model.supports_tools,
 221                    supports_images: model.supports_images,
 222                    supports_thinking: model.supports_thinking,
 223                },
 224            );
 225        }
 226
 227        models
 228            .into_values()
 229            .map(|model| {
 230                Arc::new(MistralLanguageModel {
 231                    id: LanguageModelId::from(model.id().to_string()),
 232                    model,
 233                    state: self.state.clone(),
 234                    http_client: self.http_client.clone(),
 235                    request_limiter: RateLimiter::new(4),
 236                }) as Arc<dyn LanguageModel>
 237            })
 238            .collect()
 239    }
 240
 241    fn is_authenticated(&self, cx: &App) -> bool {
 242        self.state.read(cx).is_authenticated()
 243    }
 244
 245    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 246        self.state.update(cx, |state, cx| state.authenticate(cx))
 247    }
 248
 249    fn configuration_view(
 250        &self,
 251        _target_agent: language_model::ConfigurationViewTargetAgent,
 252        window: &mut Window,
 253        cx: &mut App,
 254    ) -> AnyView {
 255        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 256            .into()
 257    }
 258
 259    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 260        self.state
 261            .update(cx, |state, cx| state.set_api_key(None, cx))
 262    }
 263}
 264
/// A single Mistral model exposed through the `LanguageModel` trait.
pub struct MistralLanguageModel {
    id: LanguageModelId,
    model: mistral::Model,
    /// Provider state; read at request time to obtain the API key and URL.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Wraps each streaming request; the provider creates it with `RateLimiter::new(4)`
    /// (presumably a cap on concurrent requests — confirm in `RateLimiter`).
    request_limiter: RateLimiter,
}
 272
 273impl MistralLanguageModel {
 274    fn stream_completion(
 275        &self,
 276        request: mistral::Request,
 277        cx: &AsyncApp,
 278    ) -> BoxFuture<
 279        'static,
 280        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
 281    > {
 282        let http_client = self.http_client.clone();
 283
 284        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
 285            let api_url = MistralLanguageModelProvider::api_url(cx);
 286            (state.api_key_state.key(&api_url), api_url)
 287        }) else {
 288            return future::ready(Err(anyhow!("App state dropped"))).boxed();
 289        };
 290
 291        let future = self.request_limiter.stream(async move {
 292            let Some(api_key) = api_key else {
 293                return Err(LanguageModelCompletionError::NoApiKey {
 294                    provider: PROVIDER_NAME,
 295                });
 296            };
 297            let request =
 298                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
 299            let response = request.await?;
 300            Ok(response)
 301        });
 302
 303        async move { Ok(future.await?.boxed()) }.boxed()
 304    }
 305}
 306
 307impl LanguageModel for MistralLanguageModel {
 308    fn id(&self) -> LanguageModelId {
 309        self.id.clone()
 310    }
 311
 312    fn name(&self) -> LanguageModelName {
 313        LanguageModelName::from(self.model.display_name().to_string())
 314    }
 315
 316    fn provider_id(&self) -> LanguageModelProviderId {
 317        PROVIDER_ID
 318    }
 319
 320    fn provider_name(&self) -> LanguageModelProviderName {
 321        PROVIDER_NAME
 322    }
 323
 324    fn supports_tools(&self) -> bool {
 325        self.model.supports_tools()
 326    }
 327
 328    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
 329        self.model.supports_tools()
 330    }
 331
 332    fn supports_images(&self) -> bool {
 333        self.model.supports_images()
 334    }
 335
 336    fn telemetry_id(&self) -> String {
 337        format!("mistral/{}", self.model.id())
 338    }
 339
 340    fn max_token_count(&self) -> u64 {
 341        self.model.max_token_count()
 342    }
 343
 344    fn max_output_tokens(&self) -> Option<u64> {
 345        self.model.max_output_tokens()
 346    }
 347
 348    fn count_tokens(
 349        &self,
 350        request: LanguageModelRequest,
 351        cx: &App,
 352    ) -> BoxFuture<'static, Result<u64>> {
 353        cx.background_spawn(async move {
 354            let messages = request
 355                .messages
 356                .into_iter()
 357                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 358                    role: match message.role {
 359                        Role::User => "user".into(),
 360                        Role::Assistant => "assistant".into(),
 361                        Role::System => "system".into(),
 362                    },
 363                    content: Some(message.string_contents()),
 364                    name: None,
 365                    function_call: None,
 366                })
 367                .collect::<Vec<_>>();
 368
 369            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 370        })
 371        .boxed()
 372    }
 373
 374    fn stream_completion(
 375        &self,
 376        request: LanguageModelRequest,
 377        cx: &AsyncApp,
 378    ) -> BoxFuture<
 379        'static,
 380        Result<
 381            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 382            LanguageModelCompletionError,
 383        >,
 384    > {
 385        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
 386        let stream = self.stream_completion(request, cx);
 387
 388        async move {
 389            let stream = stream.await?;
 390            let mapper = MistralEventMapper::new();
 391            Ok(mapper.map_stream(stream).boxed())
 392        }
 393        .boxed()
 394    }
 395}
 396
 397pub fn into_mistral(
 398    request: LanguageModelRequest,
 399    model: mistral::Model,
 400    max_output_tokens: Option<u64>,
 401) -> mistral::Request {
 402    let stream = true;
 403
 404    let mut messages = Vec::new();
 405    for message in &request.messages {
 406        match message.role {
 407            Role::User => {
 408                let mut message_content = mistral::MessageContent::empty();
 409                for content in &message.content {
 410                    match content {
 411                        MessageContent::Text(text) => {
 412                            message_content
 413                                .push_part(mistral::MessagePart::Text { text: text.clone() });
 414                        }
 415                        MessageContent::Image(image_content) => {
 416                            if model.supports_images() {
 417                                message_content.push_part(mistral::MessagePart::ImageUrl {
 418                                    image_url: image_content.to_base64_url(),
 419                                });
 420                            }
 421                        }
 422                        MessageContent::Thinking { text, .. } => {
 423                            if model.supports_thinking() {
 424                                message_content.push_part(mistral::MessagePart::Thinking {
 425                                    thinking: vec![mistral::ThinkingPart::Text {
 426                                        text: text.clone(),
 427                                    }],
 428                                });
 429                            }
 430                        }
 431                        MessageContent::RedactedThinking(_) => {}
 432                        MessageContent::ToolUse(_) => {
 433                            // Tool use is not supported in User messages for Mistral
 434                        }
 435                        MessageContent::ToolResult(tool_result) => {
 436                            let tool_content = match &tool_result.content {
 437                                LanguageModelToolResultContent::Text(text) => text.to_string(),
 438                                LanguageModelToolResultContent::Image(_) => {
 439                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
 440                                }
 441                            };
 442                            messages.push(mistral::RequestMessage::Tool {
 443                                content: tool_content,
 444                                tool_call_id: tool_result.tool_use_id.to_string(),
 445                            });
 446                        }
 447                    }
 448                }
 449                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
 450                {
 451                    messages.push(mistral::RequestMessage::User {
 452                        content: message_content,
 453                    });
 454                }
 455            }
 456            Role::Assistant => {
 457                for content in &message.content {
 458                    match content {
 459                        MessageContent::Text(text) => {
 460                            messages.push(mistral::RequestMessage::Assistant {
 461                                content: Some(mistral::MessageContent::Plain {
 462                                    content: text.clone(),
 463                                }),
 464                                tool_calls: Vec::new(),
 465                            });
 466                        }
 467                        MessageContent::Thinking { text, .. } => {
 468                            if model.supports_thinking() {
 469                                messages.push(mistral::RequestMessage::Assistant {
 470                                    content: Some(mistral::MessageContent::Multipart {
 471                                        content: vec![mistral::MessagePart::Thinking {
 472                                            thinking: vec![mistral::ThinkingPart::Text {
 473                                                text: text.clone(),
 474                                            }],
 475                                        }],
 476                                    }),
 477                                    tool_calls: Vec::new(),
 478                                });
 479                            }
 480                        }
 481                        MessageContent::RedactedThinking(_) => {}
 482                        MessageContent::Image(_) => {}
 483                        MessageContent::ToolUse(tool_use) => {
 484                            let tool_call = mistral::ToolCall {
 485                                id: tool_use.id.to_string(),
 486                                content: mistral::ToolCallContent::Function {
 487                                    function: mistral::FunctionContent {
 488                                        name: tool_use.name.to_string(),
 489                                        arguments: serde_json::to_string(&tool_use.input)
 490                                            .unwrap_or_default(),
 491                                    },
 492                                },
 493                            };
 494
 495                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
 496                                messages.last_mut()
 497                            {
 498                                tool_calls.push(tool_call);
 499                            } else {
 500                                messages.push(mistral::RequestMessage::Assistant {
 501                                    content: None,
 502                                    tool_calls: vec![tool_call],
 503                                });
 504                            }
 505                        }
 506                        MessageContent::ToolResult(_) => {
 507                            // Tool results are not supported in Assistant messages
 508                        }
 509                    }
 510                }
 511            }
 512            Role::System => {
 513                for content in &message.content {
 514                    match content {
 515                        MessageContent::Text(text) => {
 516                            messages.push(mistral::RequestMessage::System {
 517                                content: mistral::MessageContent::Plain {
 518                                    content: text.clone(),
 519                                },
 520                            });
 521                        }
 522                        MessageContent::Thinking { text, .. } => {
 523                            if model.supports_thinking() {
 524                                messages.push(mistral::RequestMessage::System {
 525                                    content: mistral::MessageContent::Multipart {
 526                                        content: vec![mistral::MessagePart::Thinking {
 527                                            thinking: vec![mistral::ThinkingPart::Text {
 528                                                text: text.clone(),
 529                                            }],
 530                                        }],
 531                                    },
 532                                });
 533                            }
 534                        }
 535                        MessageContent::RedactedThinking(_) => {}
 536                        MessageContent::Image(_)
 537                        | MessageContent::ToolUse(_)
 538                        | MessageContent::ToolResult(_) => {
 539                            // Images and tools are not supported in System messages
 540                        }
 541                    }
 542                }
 543            }
 544        }
 545    }
 546
 547    mistral::Request {
 548        model: model.id().to_string(),
 549        messages,
 550        stream,
 551        max_tokens: max_output_tokens,
 552        temperature: request.temperature,
 553        response_format: None,
 554        tool_choice: match request.tool_choice {
 555            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
 556                Some(mistral::ToolChoice::Auto)
 557            }
 558            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
 559                Some(mistral::ToolChoice::Any)
 560            }
 561            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
 562            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
 563            _ => None,
 564        },
 565        parallel_tool_calls: if !request.tools.is_empty() {
 566            Some(false)
 567        } else {
 568            None
 569        },
 570        tools: request
 571            .tools
 572            .into_iter()
 573            .map(|tool| mistral::ToolDefinition::Function {
 574                function: mistral::FunctionDefinition {
 575                    name: tool.name,
 576                    description: Some(tool.description),
 577                    parameters: Some(tool.input_schema),
 578                },
 579            })
 580            .collect(),
 581    }
 582}
 583
/// Converts streamed Mistral chat-completion chunks into `LanguageModelCompletionEvent`s.
///
/// Tool-call fragments are buffered across chunks until a `tool_calls` finish reason arrives.
pub struct MistralEventMapper {
    /// Partially received tool calls, keyed by the `index` reported in each delta.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
 587
 588impl MistralEventMapper {
 589    pub fn new() -> Self {
 590        Self {
 591            tool_calls_by_index: HashMap::default(),
 592        }
 593    }
 594
 595    pub fn map_stream(
 596        mut self,
 597        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
 598    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 599    {
 600        events.flat_map(move |event| {
 601            futures::stream::iter(match event {
 602                Ok(event) => self.map_event(event),
 603                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
 604            })
 605        })
 606    }
 607
 608    pub fn map_event(
 609        &mut self,
 610        event: mistral::StreamResponse,
 611    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 612        let Some(choice) = event.choices.first() else {
 613            return vec![Err(LanguageModelCompletionError::from(anyhow!(
 614                "Response contained no choices"
 615            )))];
 616        };
 617
 618        let mut events = Vec::new();
 619        if let Some(content) = choice.delta.content.as_ref() {
 620            match content {
 621                mistral::MessageContentDelta::Text(text) => {
 622                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 623                }
 624                mistral::MessageContentDelta::Parts(parts) => {
 625                    for part in parts {
 626                        match part {
 627                            mistral::MessagePart::Text { text } => {
 628                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 629                            }
 630                            mistral::MessagePart::Thinking { thinking } => {
 631                                for tp in thinking.iter().cloned() {
 632                                    match tp {
 633                                        mistral::ThinkingPart::Text { text } => {
 634                                            events.push(Ok(
 635                                                LanguageModelCompletionEvent::Thinking {
 636                                                    text,
 637                                                    signature: None,
 638                                                },
 639                                            ));
 640                                        }
 641                                    }
 642                                }
 643                            }
 644                            mistral::MessagePart::ImageUrl { .. } => {
 645                                // We currently don't emit a separate event for images in responses.
 646                            }
 647                        }
 648                    }
 649                }
 650            }
 651        }
 652
 653        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
 654            for tool_call in tool_calls {
 655                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
 656
 657                if let Some(tool_id) = tool_call.id.clone() {
 658                    entry.id = tool_id;
 659                }
 660
 661                if let Some(function) = tool_call.function.as_ref() {
 662                    if let Some(name) = function.name.clone() {
 663                        entry.name = name;
 664                    }
 665
 666                    if let Some(arguments) = function.arguments.clone() {
 667                        entry.arguments.push_str(&arguments);
 668                    }
 669                }
 670            }
 671        }
 672
 673        if let Some(usage) = event.usage {
 674            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 675                input_tokens: usage.prompt_tokens,
 676                output_tokens: usage.completion_tokens,
 677                cache_creation_input_tokens: 0,
 678                cache_read_input_tokens: 0,
 679            })));
 680        }
 681
 682        if let Some(finish_reason) = choice.finish_reason.as_deref() {
 683            match finish_reason {
 684                "stop" => {
 685                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 686                }
 687                "tool_calls" => {
 688                    events.extend(self.process_tool_calls());
 689                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 690                }
 691                unexpected => {
 692                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
 693                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 694                }
 695            }
 696        }
 697
 698        events
 699    }
 700
 701    fn process_tool_calls(
 702        &mut self,
 703    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 704        let mut results = Vec::new();
 705
 706        for (_, tool_call) in self.tool_calls_by_index.drain() {
 707            if tool_call.id.is_empty() || tool_call.name.is_empty() {
 708                results.push(Err(LanguageModelCompletionError::from(anyhow!(
 709                    "Received incomplete tool call: missing id or name"
 710                ))));
 711                continue;
 712            }
 713
 714            match serde_json::Value::from_str(&tool_call.arguments) {
 715                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
 716                    LanguageModelToolUse {
 717                        id: tool_call.id.into(),
 718                        name: tool_call.name.into(),
 719                        is_input_complete: true,
 720                        input,
 721                        raw_input: tool_call.arguments,
 722                    },
 723                ))),
 724                Err(error) => {
 725                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 726                        id: tool_call.id.into(),
 727                        tool_name: tool_call.name.into(),
 728                        raw_input: tool_call.arguments.into(),
 729                        json_parse_error: error.to_string(),
 730                    }))
 731                }
 732            }
 733        }
 734
 735        results
 736    }
 737}
 738
/// A tool call being assembled from streamed deltas.
#[derive(Default)]
struct RawToolCall {
    /// Call id; stays empty until a delta supplies one.
    id: String,
    /// Tool name; stays empty until a delta supplies one.
    name: String,
    /// JSON argument text, concatenated from every argument fragment received so far.
    arguments: String,
}
 745
/// Settings view for entering or resetting the Mistral and Codestral API keys.
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    codestral_api_key_editor: Entity<SingleLineInput>,
    state: Entity<State>,
    /// Initial credential load started in `new`; set to `None` once both keys have been attempted.
    load_credentials_task: Option<Task<()>>,
}
 752
 753impl ConfigurationView {
 754    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 755        let api_key_editor =
 756            cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2"));
 757        let codestral_api_key_editor =
 758            cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2"));
 759
 760        cx.observe(&state, |_, _, cx| {
 761            cx.notify();
 762        })
 763        .detach();
 764
 765        let load_credentials_task = Some(cx.spawn_in(window, {
 766            let state = state.clone();
 767            async move |this, cx| {
 768                if let Some(task) = state
 769                    .update(cx, |state, cx| state.authenticate(cx))
 770                    .log_err()
 771                {
 772                    // We don't log an error, because "not signed in" is also an error.
 773                    let _ = task.await;
 774                }
 775                if let Some(task) = state
 776                    .update(cx, |state, cx| state.authenticate_codestral(cx))
 777                    .log_err()
 778                {
 779                    let _ = task.await;
 780                }
 781
 782                this.update(cx, |this, cx| {
 783                    this.load_credentials_task = None;
 784                    cx.notify();
 785                })
 786                .log_err();
 787            }
 788        }));
 789
 790        Self {
 791            api_key_editor,
 792            codestral_api_key_editor,
 793            state,
 794            load_credentials_task,
 795        }
 796    }
 797
 798    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 799        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 800        if api_key.is_empty() {
 801            return;
 802        }
 803
 804        // url changes can cause the editor to be displayed again
 805        self.api_key_editor
 806            .update(cx, |editor, cx| editor.set_text("", window, cx));
 807
 808        let state = self.state.clone();
 809        cx.spawn_in(window, async move |_, cx| {
 810            state
 811                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
 812                .await
 813        })
 814        .detach_and_log_err(cx);
 815    }
 816
 817    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 818        self.api_key_editor
 819            .update(cx, |editor, cx| editor.set_text("", window, cx));
 820
 821        let state = self.state.clone();
 822        cx.spawn_in(window, async move |_, cx| {
 823            state
 824                .update(cx, |state, cx| state.set_api_key(None, cx))?
 825                .await
 826        })
 827        .detach_and_log_err(cx);
 828    }
 829
 830    fn save_codestral_api_key(
 831        &mut self,
 832        _: &menu::Confirm,
 833        window: &mut Window,
 834        cx: &mut Context<Self>,
 835    ) {
 836        let api_key = self
 837            .codestral_api_key_editor
 838            .read(cx)
 839            .text(cx)
 840            .trim()
 841            .to_string();
 842        if api_key.is_empty() {
 843            return;
 844        }
 845
 846        // url changes can cause the editor to be displayed again
 847        self.codestral_api_key_editor
 848            .update(cx, |editor, cx| editor.set_text("", window, cx));
 849
 850        let state = self.state.clone();
 851        cx.spawn_in(window, async move |_, cx| {
 852            state
 853                .update(cx, |state, cx| {
 854                    state.set_codestral_api_key(Some(api_key), cx)
 855                })?
 856                .await?;
 857            cx.update(|_window, cx| {
 858                set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
 859            })
 860        })
 861        .detach_and_log_err(cx);
 862    }
 863
 864    fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 865        self.codestral_api_key_editor
 866            .update(cx, |editor, cx| editor.set_text("", window, cx));
 867
 868        let state = self.state.clone();
 869        cx.spawn_in(window, async move |_, cx| {
 870            state
 871                .update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
 872                .await?;
 873            cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
 874        })
 875        .detach_and_log_err(cx);
 876    }
 877
 878    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
 879        !self.state.read(cx).is_authenticated()
 880    }
 881
 882    fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
 883        let key_state = &self.state.read(cx).codestral_api_key_state;
 884        let should_show_editor = !key_state.has_key();
 885        let env_var_set = key_state.is_from_env_var();
 886        if should_show_editor {
 887            v_flex()
 888                .id("codestral")
 889                .size_full()
 890                .mt_2()
 891                .on_action(cx.listener(Self::save_codestral_api_key))
 892                .child(Label::new(
 893                    "To use Codestral as an edit prediction provider, \
 894                    you need to add a Codestral-specific API key. Follow these steps:",
 895                ))
 896                .child(
 897                    List::new()
 898                        .child(InstructionListItem::new(
 899                            "Create one by visiting",
 900                            Some("the Codestral section of Mistral's console"),
 901                            Some("https://console.mistral.ai/codestral"),
 902                        ))
 903                        .child(InstructionListItem::text_only("Paste your API key below and hit enter")),
 904                )
 905                .child(self.codestral_api_key_editor.clone())
 906                .child(
 907                    Label::new(
 908                        format!("You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
 909                    )
 910                    .size(LabelSize::Small).color(Color::Muted),
 911                ).into_any()
 912        } else {
 913            h_flex()
 914                .id("codestral")
 915                .mt_2()
 916                .p_1()
 917                .justify_between()
 918                .rounded_md()
 919                .border_1()
 920                .border_color(cx.theme().colors().border)
 921                .bg(cx.theme().colors().background)
 922                .child(
 923                    h_flex()
 924                        .gap_1()
 925                        .child(Icon::new(IconName::Check).color(Color::Success))
 926                        .child(Label::new(if env_var_set {
 927                            format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
 928                        } else {
 929                            "Codestral API key configured".to_string()
 930                        })),
 931                )
 932                .child(
 933                    Button::new("reset-key", "Reset Key")
 934                        .label_size(LabelSize::Small)
 935                        .icon(Some(IconName::Trash))
 936                        .icon_size(IconSize::Small)
 937                        .icon_position(IconPosition::Start)
 938                        .disabled(env_var_set)
 939                        .when(env_var_set, |this| {
 940                            this.tooltip(Tooltip::text(format!(
 941                                "To reset your API key, \
 942                                unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
 943                            )))
 944                        })
 945                        .on_click(
 946                            cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
 947                        ),
 948                ).into_any()
 949        }
 950    }
 951}
 952
 953impl Render for ConfigurationView {
 954    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 955        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 956
 957        if self.load_credentials_task.is_some() {
 958            div().child(Label::new("Loading credentials...")).into_any()
 959        } else if self.should_render_api_key_editor(cx) {
 960            v_flex()
 961                .size_full()
 962                .on_action(cx.listener(Self::save_api_key))
 963                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
 964                .child(
 965                    List::new()
 966                        .child(InstructionListItem::new(
 967                            "Create one by visiting",
 968                            Some("Mistral's console"),
 969                            Some("https://console.mistral.ai/api-keys"),
 970                        ))
 971                        .child(InstructionListItem::text_only(
 972                            "Ensure your Mistral account has credits",
 973                        ))
 974                        .child(InstructionListItem::text_only(
 975                            "Paste your API key below and hit enter to start using the assistant",
 976                        )),
 977                )
 978                .child(self.api_key_editor.clone())
 979                .child(
 980                    Label::new(
 981                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
 982                    )
 983                    .size(LabelSize::Small).color(Color::Muted),
 984                )
 985                .child(self.render_codestral_api_key_editor(cx))
 986                .into_any()
 987        } else {
 988            v_flex()
 989                .size_full()
 990                .child(
 991                    h_flex()
 992                        .mt_1()
 993                        .p_1()
 994                        .justify_between()
 995                        .rounded_md()
 996                        .border_1()
 997                        .border_color(cx.theme().colors().border)
 998                        .bg(cx.theme().colors().background)
 999                        .child(
1000                            h_flex()
1001                                .gap_1()
1002                                .child(Icon::new(IconName::Check).color(Color::Success))
1003                                .child(Label::new(if env_var_set {
1004                                    format!(
1005                                        "API key set in {API_KEY_ENV_VAR_NAME} environment variable"
1006                                    )
1007                                } else {
1008                                    let api_url = MistralLanguageModelProvider::api_url(cx);
1009                                    if api_url == MISTRAL_API_URL {
1010                                        "API key configured".to_string()
1011                                    } else {
1012                                        format!(
1013                                            "API key configured for {}",
1014                                            truncate_and_trailoff(&api_url, 32)
1015                                        )
1016                                    }
1017                                })),
1018                        )
1019                        .child(
1020                            Button::new("reset-key", "Reset Key")
1021                                .label_size(LabelSize::Small)
1022                                .icon(Some(IconName::Trash))
1023                                .icon_size(IconSize::Small)
1024                                .icon_position(IconPosition::Start)
1025                                .disabled(env_var_set)
1026                                .when(env_var_set, |this| {
1027                                    this.tooltip(Tooltip::text(format!(
1028                                        "To reset your API key, \
1029                                        unset the {API_KEY_ENV_VAR_NAME} environment variable."
1030                                    )))
1031                                })
1032                                .on_click(cx.listener(|this, _, window, cx| {
1033                                    this.reset_api_key(window, cx)
1034                                })),
1035                        ),
1036                )
1037                .child(self.render_codestral_api_key_editor(cx))
1038                .into_any()
1039        }
1040    }
1041}
1042
1043fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
1044    let fs = <dyn Fs>::global(cx);
1045    update_settings_file(fs, cx, move |settings, _| {
1046        settings
1047            .project
1048            .all_languages
1049            .features
1050            .get_or_insert_default()
1051            .edit_prediction_provider = Some(provider);
1052    });
1053}
1054
#[cfg(test)]
mod tests {
    use super::*;
    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// Builds an uncached request message with the given role and content parts.
    fn message(role: Role, content: Vec<MessageContent>) -> LanguageModelRequestMessage {
        LanguageModelRequestMessage {
            role,
            content,
            cache: false,
        }
    }

    /// A system + user text request maps to a streaming Mistral request for the chosen
    /// model, keeps the temperature, and preserves the message count.
    #[test]
    fn test_into_mistral_basic_conversion() {
        let request = LanguageModelRequest {
            messages: vec![
                message(
                    Role::System,
                    vec![MessageContent::Text("System prompt".into())],
                ),
                message(Role::User, vec![MessageContent::Text("Hello".into())]),
            ],
            temperature: Some(0.5),
            tools: vec![],
            tool_choice: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);

        assert_eq!(mistral_request.model, "mistral-small-latest");
        assert_eq!(mistral_request.temperature, Some(0.5));
        assert_eq!(mistral_request.messages.len(), 2);
        assert!(mistral_request.stream);
    }

    /// A user message mixing text and an image becomes one multipart user message whose
    /// image part is a PNG data URL.
    #[test]
    fn test_into_mistral_with_image() {
        let request = LanguageModelRequest {
            messages: vec![message(
                Role::User,
                vec![
                    MessageContent::Text("What's in this image?".into()),
                    MessageContent::Image(LanguageModelImage {
                        source: "base64data".into(),
                        size: Default::default(),
                    }),
                ],
            )],
            tools: vec![],
            tool_choice: None,
            temperature: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            stop: vec![],
            thinking_allowed: true,
        };

        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);

        assert_eq!(mistral_request.messages.len(), 1);
        let mistral::RequestMessage::User {
            content: mistral::MessageContent::Multipart { content: parts },
        } = &mistral_request.messages[0]
        else {
            panic!("expected a multipart user message");
        };

        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            mistral::MessagePart::Text { text } if text == "What's in this image?"
        ));
        assert!(matches!(
            &parts[1],
            mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
        ));
    }
}