mistral.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use collections::BTreeMap;
   3use credentials_provider::CredentialsProvider;
   4use editor::{Editor, EditorElement, EditorStyle};
   5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
   6use gpui::{
   7    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
   8};
   9use http_client::HttpClient;
  10use language_model::{
  11    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  12    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  13    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  14    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  15    RateLimiter, Role, StopReason, TokenUsage,
  16};
  17use mistral::StreamResponse;
  18use schemars::JsonSchema;
  19use serde::{Deserialize, Serialize};
  20use settings::{Settings, SettingsStore};
  21use std::collections::HashMap;
  22use std::pin::Pin;
  23use std::str::FromStr;
  24use std::sync::Arc;
  25use strum::IntoEnumIterator;
  26use theme::ThemeSettings;
  27use ui::{Icon, IconName, List, Tooltip, prelude::*};
  28use util::ResultExt;
  29
  30use crate::{AllLanguageModelSettings, ui::InstructionListItem};
  31
/// Provider identifier returned by the provider's `id` and by each model's `provider_id`.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Provider name returned by the provider's `name` and by each model's `provider_name`.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
  34
/// Settings consumed by the Mistral provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// API URL passed to `mistral::stream_completion`; it is also the key under which the
    /// API key is read from, written to and deleted from the credentials provider.
    pub api_url: String,
    /// Models declared in settings; each one is turned into a `mistral::Model::Custom` by
    /// `provided_models`.
    pub available_models: Vec<AvailableModel>,
}
  40
/// A model declared in settings.
///
/// Every field is copied unchanged into `mistral::Model::Custom` by `provided_models`; the
/// meaning of the limits and capability flags is defined by the `mistral` crate.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model name. It is also the key in the provider's model map, so an entry whose name
    /// equals a built-in model id replaces that built-in model.
    pub name: String,
    /// Optional label forwarded to the custom model.
    pub display_name: Option<String>,
    /// Token limit forwarded to the custom model (presumably the context window size —
    /// TODO confirm against `mistral::Model`).
    pub max_tokens: u64,
    /// Optional output-token limit forwarded to the custom model.
    pub max_output_tokens: Option<u64>,
    /// Optional completion-token limit forwarded to the custom model.
    pub max_completion_tokens: Option<u64>,
    /// Optional tool-support flag forwarded to the custom model.
    pub supports_tools: Option<bool>,
    /// Optional image-support flag forwarded to the custom model.
    pub supports_images: Option<bool>,
    /// Optional thinking-support flag forwarded to the custom model.
    pub supports_thinking: Option<bool>,
}
  52
/// Language-model provider for Mistral.
pub struct MistralLanguageModelProvider {
    /// HTTP client handed to every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with every model and with the configuration view.
    state: gpui::Entity<State>,
}
  57
/// Authentication state of the Mistral provider.
pub struct State {
    /// The API key currently held in memory; `None` means "not authenticated".
    api_key: Option<String>,
    /// `true` when `api_key` was read from the `MISTRAL_API_KEY` environment variable.
    api_key_from_env: bool,
    /// Subscription returned by the `SettingsStore` observation registered when the state
    /// is created (presumably held here so the observation stays active — confirm against
    /// gpui's `Subscription` semantics).
    _subscription: Subscription,
}
  63
/// Environment variable that `State::authenticate` checks before consulting the
/// credentials provider.
const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY";
  65
  66impl State {
  67    fn is_authenticated(&self) -> bool {
  68        self.api_key.is_some()
  69    }
  70
  71    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
  72        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  73        let api_url = AllLanguageModelSettings::get_global(cx)
  74            .mistral
  75            .api_url
  76            .clone();
  77        cx.spawn(async move |this, cx| {
  78            credentials_provider
  79                .delete_credentials(&api_url, cx)
  80                .await
  81                .log_err();
  82            this.update(cx, |this, cx| {
  83                this.api_key = None;
  84                this.api_key_from_env = false;
  85                cx.notify();
  86            })
  87        })
  88    }
  89
  90    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
  91        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  92        let api_url = AllLanguageModelSettings::get_global(cx)
  93            .mistral
  94            .api_url
  95            .clone();
  96        cx.spawn(async move |this, cx| {
  97            credentials_provider
  98                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
  99                .await?;
 100            this.update(cx, |this, cx| {
 101                this.api_key = Some(api_key);
 102                cx.notify();
 103            })
 104        })
 105    }
 106
 107    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 108        if self.is_authenticated() {
 109            return Task::ready(Ok(()));
 110        }
 111
 112        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 113        let api_url = AllLanguageModelSettings::get_global(cx)
 114            .mistral
 115            .api_url
 116            .clone();
 117        cx.spawn(async move |this, cx| {
 118            let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
 119                (api_key, true)
 120            } else {
 121                let (_, api_key) = credentials_provider
 122                    .read_credentials(&api_url, cx)
 123                    .await?
 124                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 125                (
 126                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 127                    false,
 128                )
 129            };
 130            this.update(cx, |this, cx| {
 131                this.api_key = Some(api_key);
 132                this.api_key_from_env = from_env;
 133                cx.notify();
 134            })?;
 135
 136            Ok(())
 137        })
 138    }
 139}
 140
 141impl MistralLanguageModelProvider {
 142    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 143        let state = cx.new(|cx| State {
 144            api_key: None,
 145            api_key_from_env: false,
 146            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
 147                cx.notify();
 148            }),
 149        });
 150
 151        Self { http_client, state }
 152    }
 153
 154    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
 155        Arc::new(MistralLanguageModel {
 156            id: LanguageModelId::from(model.id().to_string()),
 157            model,
 158            state: self.state.clone(),
 159            http_client: self.http_client.clone(),
 160            request_limiter: RateLimiter::new(4),
 161        })
 162    }
 163}
 164
 165impl LanguageModelProviderState for MistralLanguageModelProvider {
 166    type ObservableEntity = State;
 167
 168    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
 169        Some(self.state.clone())
 170    }
 171}
 172
 173impl LanguageModelProvider for MistralLanguageModelProvider {
 174    fn id(&self) -> LanguageModelProviderId {
 175        PROVIDER_ID
 176    }
 177
 178    fn name(&self) -> LanguageModelProviderName {
 179        PROVIDER_NAME
 180    }
 181
 182    fn icon(&self) -> IconName {
 183        IconName::AiMistral
 184    }
 185
 186    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 187        Some(self.create_language_model(mistral::Model::default()))
 188    }
 189
 190    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 191        Some(self.create_language_model(mistral::Model::default_fast()))
 192    }
 193
 194    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 195        let mut models = BTreeMap::default();
 196
 197        // Add base models from mistral::Model::iter()
 198        for model in mistral::Model::iter() {
 199            if !matches!(model, mistral::Model::Custom { .. }) {
 200                models.insert(model.id().to_string(), model);
 201            }
 202        }
 203
 204        // Override with available models from settings
 205        for model in &AllLanguageModelSettings::get_global(cx)
 206            .mistral
 207            .available_models
 208        {
 209            models.insert(
 210                model.name.clone(),
 211                mistral::Model::Custom {
 212                    name: model.name.clone(),
 213                    display_name: model.display_name.clone(),
 214                    max_tokens: model.max_tokens,
 215                    max_output_tokens: model.max_output_tokens,
 216                    max_completion_tokens: model.max_completion_tokens,
 217                    supports_tools: model.supports_tools,
 218                    supports_images: model.supports_images,
 219                    supports_thinking: model.supports_thinking,
 220                },
 221            );
 222        }
 223
 224        models
 225            .into_values()
 226            .map(|model| {
 227                Arc::new(MistralLanguageModel {
 228                    id: LanguageModelId::from(model.id().to_string()),
 229                    model,
 230                    state: self.state.clone(),
 231                    http_client: self.http_client.clone(),
 232                    request_limiter: RateLimiter::new(4),
 233                }) as Arc<dyn LanguageModel>
 234            })
 235            .collect()
 236    }
 237
 238    fn is_authenticated(&self, cx: &App) -> bool {
 239        self.state.read(cx).is_authenticated()
 240    }
 241
 242    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 243        self.state.update(cx, |state, cx| state.authenticate(cx))
 244    }
 245
 246    fn configuration_view(
 247        &self,
 248        _target_agent: language_model::ConfigurationViewTargetAgent,
 249        window: &mut Window,
 250        cx: &mut App,
 251    ) -> AnyView {
 252        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 253            .into()
 254    }
 255
 256    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 257        self.state.update(cx, |state, cx| state.reset_api_key(cx))
 258    }
 259}
 260
/// A single Mistral model exposed through the `LanguageModel` trait.
pub struct MistralLanguageModel {
    /// Identifier derived from `model.id()`.
    id: LanguageModelId,
    /// The underlying model description from the `mistral` crate.
    model: mistral::Model,
    /// Provider state from which the API key is read for each request.
    state: gpui::Entity<State>,
    /// HTTP client used to issue completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Limiter through which every completion request is routed.
    request_limiter: RateLimiter,
}
 268
 269impl MistralLanguageModel {
 270    fn stream_completion(
 271        &self,
 272        request: mistral::Request,
 273        cx: &AsyncApp,
 274    ) -> BoxFuture<
 275        'static,
 276        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
 277    > {
 278        let http_client = self.http_client.clone();
 279        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
 280            let settings = &AllLanguageModelSettings::get_global(cx).mistral;
 281            (state.api_key.clone(), settings.api_url.clone())
 282        }) else {
 283            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
 284        };
 285
 286        let future = self.request_limiter.stream(async move {
 287            let api_key = api_key.context("Missing Mistral API Key")?;
 288            let request =
 289                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
 290            let response = request.await?;
 291            Ok(response)
 292        });
 293
 294        async move { Ok(future.await?.boxed()) }.boxed()
 295    }
 296}
 297
 298impl LanguageModel for MistralLanguageModel {
 299    fn id(&self) -> LanguageModelId {
 300        self.id.clone()
 301    }
 302
 303    fn name(&self) -> LanguageModelName {
 304        LanguageModelName::from(self.model.display_name().to_string())
 305    }
 306
 307    fn provider_id(&self) -> LanguageModelProviderId {
 308        PROVIDER_ID
 309    }
 310
 311    fn provider_name(&self) -> LanguageModelProviderName {
 312        PROVIDER_NAME
 313    }
 314
 315    fn supports_tools(&self) -> bool {
 316        self.model.supports_tools()
 317    }
 318
 319    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
 320        self.model.supports_tools()
 321    }
 322
 323    fn supports_images(&self) -> bool {
 324        self.model.supports_images()
 325    }
 326
 327    fn telemetry_id(&self) -> String {
 328        format!("mistral/{}", self.model.id())
 329    }
 330
 331    fn max_token_count(&self) -> u64 {
 332        self.model.max_token_count()
 333    }
 334
 335    fn max_output_tokens(&self) -> Option<u64> {
 336        self.model.max_output_tokens()
 337    }
 338
 339    fn count_tokens(
 340        &self,
 341        request: LanguageModelRequest,
 342        cx: &App,
 343    ) -> BoxFuture<'static, Result<u64>> {
 344        cx.background_spawn(async move {
 345            let messages = request
 346                .messages
 347                .into_iter()
 348                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 349                    role: match message.role {
 350                        Role::User => "user".into(),
 351                        Role::Assistant => "assistant".into(),
 352                        Role::System => "system".into(),
 353                    },
 354                    content: Some(message.string_contents()),
 355                    name: None,
 356                    function_call: None,
 357                })
 358                .collect::<Vec<_>>();
 359
 360            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 361        })
 362        .boxed()
 363    }
 364
 365    fn stream_completion(
 366        &self,
 367        request: LanguageModelRequest,
 368        cx: &AsyncApp,
 369    ) -> BoxFuture<
 370        'static,
 371        Result<
 372            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 373            LanguageModelCompletionError,
 374        >,
 375    > {
 376        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
 377        let stream = self.stream_completion(request, cx);
 378
 379        async move {
 380            let stream = stream.await?;
 381            let mapper = MistralEventMapper::new();
 382            Ok(mapper.map_stream(stream).boxed())
 383        }
 384        .boxed()
 385    }
 386}
 387
 388pub fn into_mistral(
 389    request: LanguageModelRequest,
 390    model: mistral::Model,
 391    max_output_tokens: Option<u64>,
 392) -> mistral::Request {
 393    let stream = true;
 394
 395    let mut messages = Vec::new();
 396    for message in &request.messages {
 397        match message.role {
 398            Role::User => {
 399                let mut message_content = mistral::MessageContent::empty();
 400                for content in &message.content {
 401                    match content {
 402                        MessageContent::Text(text) => {
 403                            message_content
 404                                .push_part(mistral::MessagePart::Text { text: text.clone() });
 405                        }
 406                        MessageContent::Image(image_content) => {
 407                            if model.supports_images() {
 408                                message_content.push_part(mistral::MessagePart::ImageUrl {
 409                                    image_url: image_content.to_base64_url(),
 410                                });
 411                            }
 412                        }
 413                        MessageContent::Thinking { text, .. } => {
 414                            if model.supports_thinking() {
 415                                message_content.push_part(mistral::MessagePart::Thinking {
 416                                    thinking: vec![mistral::ThinkingPart::Text {
 417                                        text: text.clone(),
 418                                    }],
 419                                });
 420                            }
 421                        }
 422                        MessageContent::RedactedThinking(_) => {}
 423                        MessageContent::ToolUse(_) => {
 424                            // Tool use is not supported in User messages for Mistral
 425                        }
 426                        MessageContent::ToolResult(tool_result) => {
 427                            let tool_content = match &tool_result.content {
 428                                LanguageModelToolResultContent::Text(text) => text.to_string(),
 429                                LanguageModelToolResultContent::Image(_) => {
 430                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
 431                                }
 432                            };
 433                            messages.push(mistral::RequestMessage::Tool {
 434                                content: tool_content,
 435                                tool_call_id: tool_result.tool_use_id.to_string(),
 436                            });
 437                        }
 438                    }
 439                }
 440                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
 441                {
 442                    messages.push(mistral::RequestMessage::User {
 443                        content: message_content,
 444                    });
 445                }
 446            }
 447            Role::Assistant => {
 448                for content in &message.content {
 449                    match content {
 450                        MessageContent::Text(text) => {
 451                            messages.push(mistral::RequestMessage::Assistant {
 452                                content: Some(mistral::MessageContent::Plain {
 453                                    content: text.clone(),
 454                                }),
 455                                tool_calls: Vec::new(),
 456                            });
 457                        }
 458                        MessageContent::Thinking { text, .. } => {
 459                            if model.supports_thinking() {
 460                                messages.push(mistral::RequestMessage::Assistant {
 461                                    content: Some(mistral::MessageContent::Multipart {
 462                                        content: vec![mistral::MessagePart::Thinking {
 463                                            thinking: vec![mistral::ThinkingPart::Text {
 464                                                text: text.clone(),
 465                                            }],
 466                                        }],
 467                                    }),
 468                                    tool_calls: Vec::new(),
 469                                });
 470                            }
 471                        }
 472                        MessageContent::RedactedThinking(_) => {}
 473                        MessageContent::Image(_) => {}
 474                        MessageContent::ToolUse(tool_use) => {
 475                            let tool_call = mistral::ToolCall {
 476                                id: tool_use.id.to_string(),
 477                                content: mistral::ToolCallContent::Function {
 478                                    function: mistral::FunctionContent {
 479                                        name: tool_use.name.to_string(),
 480                                        arguments: serde_json::to_string(&tool_use.input)
 481                                            .unwrap_or_default(),
 482                                    },
 483                                },
 484                            };
 485
 486                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
 487                                messages.last_mut()
 488                            {
 489                                tool_calls.push(tool_call);
 490                            } else {
 491                                messages.push(mistral::RequestMessage::Assistant {
 492                                    content: None,
 493                                    tool_calls: vec![tool_call],
 494                                });
 495                            }
 496                        }
 497                        MessageContent::ToolResult(_) => {
 498                            // Tool results are not supported in Assistant messages
 499                        }
 500                    }
 501                }
 502            }
 503            Role::System => {
 504                for content in &message.content {
 505                    match content {
 506                        MessageContent::Text(text) => {
 507                            messages.push(mistral::RequestMessage::System {
 508                                content: mistral::MessageContent::Plain {
 509                                    content: text.clone(),
 510                                },
 511                            });
 512                        }
 513                        MessageContent::Thinking { text, .. } => {
 514                            if model.supports_thinking() {
 515                                messages.push(mistral::RequestMessage::System {
 516                                    content: mistral::MessageContent::Multipart {
 517                                        content: vec![mistral::MessagePart::Thinking {
 518                                            thinking: vec![mistral::ThinkingPart::Text {
 519                                                text: text.clone(),
 520                                            }],
 521                                        }],
 522                                    },
 523                                });
 524                            }
 525                        }
 526                        MessageContent::RedactedThinking(_) => {}
 527                        MessageContent::Image(_)
 528                        | MessageContent::ToolUse(_)
 529                        | MessageContent::ToolResult(_) => {
 530                            // Images and tools are not supported in System messages
 531                        }
 532                    }
 533                }
 534            }
 535        }
 536    }
 537
 538    mistral::Request {
 539        model: model.id().to_string(),
 540        messages,
 541        stream,
 542        max_tokens: max_output_tokens,
 543        temperature: request.temperature,
 544        response_format: None,
 545        tool_choice: match request.tool_choice {
 546            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
 547                Some(mistral::ToolChoice::Auto)
 548            }
 549            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
 550                Some(mistral::ToolChoice::Any)
 551            }
 552            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
 553            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
 554            _ => None,
 555        },
 556        parallel_tool_calls: if !request.tools.is_empty() {
 557            Some(false)
 558        } else {
 559            None
 560        },
 561        tools: request
 562            .tools
 563            .into_iter()
 564            .map(|tool| mistral::ToolDefinition::Function {
 565                function: mistral::FunctionDefinition {
 566                    name: tool.name,
 567                    description: Some(tool.description),
 568                    parameters: Some(tool.input_schema),
 569                },
 570            })
 571            .collect(),
 572    }
 573}
 574
/// Translates streamed Mistral responses into `LanguageModelCompletionEvent`s.
pub struct MistralEventMapper {
    /// Tool calls accumulated so far, keyed by the `index` carried on each tool-call delta.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
 578
 579impl MistralEventMapper {
 580    pub fn new() -> Self {
 581        Self {
 582            tool_calls_by_index: HashMap::default(),
 583        }
 584    }
 585
 586    pub fn map_stream(
 587        mut self,
 588        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
 589    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 590    {
 591        events.flat_map(move |event| {
 592            futures::stream::iter(match event {
 593                Ok(event) => self.map_event(event),
 594                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
 595            })
 596        })
 597    }
 598
 599    pub fn map_event(
 600        &mut self,
 601        event: mistral::StreamResponse,
 602    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 603        let Some(choice) = event.choices.first() else {
 604            return vec![Err(LanguageModelCompletionError::from(anyhow!(
 605                "Response contained no choices"
 606            )))];
 607        };
 608
 609        let mut events = Vec::new();
 610        if let Some(content) = choice.delta.content.as_ref() {
 611            match content {
 612                mistral::MessageContentDelta::Text(text) => {
 613                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 614                }
 615                mistral::MessageContentDelta::Parts(parts) => {
 616                    for part in parts {
 617                        match part {
 618                            mistral::MessagePart::Text { text } => {
 619                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 620                            }
 621                            mistral::MessagePart::Thinking { thinking } => {
 622                                for tp in thinking.iter().cloned() {
 623                                    match tp {
 624                                        mistral::ThinkingPart::Text { text } => {
 625                                            events.push(Ok(
 626                                                LanguageModelCompletionEvent::Thinking {
 627                                                    text,
 628                                                    signature: None,
 629                                                },
 630                                            ));
 631                                        }
 632                                    }
 633                                }
 634                            }
 635                            mistral::MessagePart::ImageUrl { .. } => {
 636                                // We currently don't emit a separate event for images in responses.
 637                            }
 638                        }
 639                    }
 640                }
 641            }
 642        }
 643
 644        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
 645            for tool_call in tool_calls {
 646                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
 647
 648                if let Some(tool_id) = tool_call.id.clone() {
 649                    entry.id = tool_id;
 650                }
 651
 652                if let Some(function) = tool_call.function.as_ref() {
 653                    if let Some(name) = function.name.clone() {
 654                        entry.name = name;
 655                    }
 656
 657                    if let Some(arguments) = function.arguments.clone() {
 658                        entry.arguments.push_str(&arguments);
 659                    }
 660                }
 661            }
 662        }
 663
 664        if let Some(usage) = event.usage {
 665            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 666                input_tokens: usage.prompt_tokens,
 667                output_tokens: usage.completion_tokens,
 668                cache_creation_input_tokens: 0,
 669                cache_read_input_tokens: 0,
 670            })));
 671        }
 672
 673        if let Some(finish_reason) = choice.finish_reason.as_deref() {
 674            match finish_reason {
 675                "stop" => {
 676                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 677                }
 678                "tool_calls" => {
 679                    events.extend(self.process_tool_calls());
 680                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 681                }
 682                unexpected => {
 683                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
 684                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 685                }
 686            }
 687        }
 688
 689        events
 690    }
 691
 692    fn process_tool_calls(
 693        &mut self,
 694    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 695        let mut results = Vec::new();
 696
 697        for (_, tool_call) in self.tool_calls_by_index.drain() {
 698            if tool_call.id.is_empty() || tool_call.name.is_empty() {
 699                results.push(Err(LanguageModelCompletionError::from(anyhow!(
 700                    "Received incomplete tool call: missing id or name"
 701                ))));
 702                continue;
 703            }
 704
 705            match serde_json::Value::from_str(&tool_call.arguments) {
 706                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
 707                    LanguageModelToolUse {
 708                        id: tool_call.id.into(),
 709                        name: tool_call.name.into(),
 710                        is_input_complete: true,
 711                        input,
 712                        raw_input: tool_call.arguments,
 713                    },
 714                ))),
 715                Err(error) => {
 716                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 717                        id: tool_call.id.into(),
 718                        tool_name: tool_call.name.into(),
 719                        raw_input: tool_call.arguments.into(),
 720                        json_parse_error: error.to_string(),
 721                    }))
 722                }
 723            }
 724        }
 725
 726        results
 727    }
 728}
 729
/// Accumulator for one tool call that arrives in streamed fragments.
#[derive(Default)]
struct RawToolCall {
    /// Tool-call id; empty until a delta carrying an id has been seen.
    id: String,
    /// Function name; empty until a delta carrying a name has been seen.
    name: String,
    /// Concatenation of all argument fragments received so far.
    arguments: String,
}
 736
/// View for entering and resetting the Mistral API key.
struct ConfigurationView {
    /// Single-line editor in which the API key is typed.
    api_key_editor: Entity<Editor>,
    /// Provider state that stores the key.
    state: gpui::Entity<State>,
    /// Task for the authentication attempt started in `new`; reset to `None` by that task
    /// once the attempt has finished.
    load_credentials_task: Option<Task<()>>,
}
 742
 743impl ConfigurationView {
 744    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 745        let api_key_editor = cx.new(|cx| {
 746            let mut editor = Editor::single_line(window, cx);
 747            editor.set_placeholder_text("0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2", window, cx);
 748            editor
 749        });
 750
 751        cx.observe(&state, |_, _, cx| {
 752            cx.notify();
 753        })
 754        .detach();
 755
 756        let load_credentials_task = Some(cx.spawn_in(window, {
 757            let state = state.clone();
 758            async move |this, cx| {
 759                if let Some(task) = state
 760                    .update(cx, |state, cx| state.authenticate(cx))
 761                    .log_err()
 762                {
 763                    // We don't log an error, because "not signed in" is also an error.
 764                    let _ = task.await;
 765                }
 766
 767                this.update(cx, |this, cx| {
 768                    this.load_credentials_task = None;
 769                    cx.notify();
 770                })
 771                .log_err();
 772            }
 773        }));
 774
 775        Self {
 776            api_key_editor,
 777            state,
 778            load_credentials_task,
 779        }
 780    }
 781
 782    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 783        let api_key = self.api_key_editor.read(cx).text(cx);
 784        if api_key.is_empty() {
 785            return;
 786        }
 787
 788        let state = self.state.clone();
 789        cx.spawn_in(window, async move |_, cx| {
 790            state
 791                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
 792                .await
 793        })
 794        .detach_and_log_err(cx);
 795
 796        cx.notify();
 797    }
 798
 799    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 800        self.api_key_editor
 801            .update(cx, |editor, cx| editor.set_text("", window, cx));
 802
 803        let state = self.state.clone();
 804        cx.spawn_in(window, async move |_, cx| {
 805            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
 806        })
 807        .detach_and_log_err(cx);
 808
 809        cx.notify();
 810    }
 811
 812    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
 813        let settings = ThemeSettings::get_global(cx);
 814        let text_style = TextStyle {
 815            color: cx.theme().colors().text,
 816            font_family: settings.ui_font.family.clone(),
 817            font_features: settings.ui_font.features.clone(),
 818            font_fallbacks: settings.ui_font.fallbacks.clone(),
 819            font_size: rems(0.875).into(),
 820            font_weight: settings.ui_font.weight,
 821            font_style: FontStyle::Normal,
 822            line_height: relative(1.3),
 823            white_space: WhiteSpace::Normal,
 824            ..Default::default()
 825        };
 826        EditorElement::new(
 827            &self.api_key_editor,
 828            EditorStyle {
 829                background: cx.theme().colors().editor_background,
 830                local_player: cx.theme().players().local(),
 831                text: text_style,
 832                ..Default::default()
 833            },
 834        )
 835    }
 836
 837    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
 838        !self.state.read(cx).is_authenticated()
 839    }
 840}
 841
 842impl Render for ConfigurationView {
 843    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 844        let env_var_set = self.state.read(cx).api_key_from_env;
 845
 846        if self.load_credentials_task.is_some() {
 847            div().child(Label::new("Loading credentials...")).into_any()
 848        } else if self.should_render_editor(cx) {
 849            v_flex()
 850                .size_full()
 851                .on_action(cx.listener(Self::save_api_key))
 852                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
 853                .child(
 854                    List::new()
 855                        .child(InstructionListItem::new(
 856                            "Create one by visiting",
 857                            Some("Mistral's console"),
 858                            Some("https://console.mistral.ai/api-keys"),
 859                        ))
 860                        .child(InstructionListItem::text_only(
 861                            "Ensure your Mistral account has credits",
 862                        ))
 863                        .child(InstructionListItem::text_only(
 864                            "Paste your API key below and hit enter to start using the assistant",
 865                        )),
 866                )
 867                .child(
 868                    h_flex()
 869                        .w_full()
 870                        .my_2()
 871                        .px_2()
 872                        .py_1()
 873                        .bg(cx.theme().colors().editor_background)
 874                        .border_1()
 875                        .border_color(cx.theme().colors().border)
 876                        .rounded_sm()
 877                        .child(self.render_api_key_editor(cx)),
 878                )
 879                .child(
 880                    Label::new(
 881                        format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."),
 882                    )
 883                    .size(LabelSize::Small).color(Color::Muted),
 884                )
 885                .into_any()
 886        } else {
 887            h_flex()
 888                .mt_1()
 889                .p_1()
 890                .justify_between()
 891                .rounded_md()
 892                .border_1()
 893                .border_color(cx.theme().colors().border)
 894                .bg(cx.theme().colors().background)
 895                .child(
 896                    h_flex()
 897                        .gap_1()
 898                        .child(Icon::new(IconName::Check).color(Color::Success))
 899                        .child(Label::new(if env_var_set {
 900                            format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.")
 901                        } else {
 902                            "API key configured.".to_string()
 903                        })),
 904                )
 905                .child(
 906                    Button::new("reset-key", "Reset Key")
 907                        .label_size(LabelSize::Small)
 908                        .icon(Some(IconName::Trash))
 909                        .icon_size(IconSize::Small)
 910                        .icon_position(IconPosition::Start)
 911                        .disabled(env_var_set)
 912                        .when(env_var_set, |this| {
 913                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable.")))
 914                        })
 915                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
 916                )
 917                .into_any()
 918        }
 919    }
 920}
 921
 922#[cfg(test)]
 923mod tests {
 924    use super::*;
 925    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
 926
 927    #[test]
 928    fn test_into_mistral_basic_conversion() {
 929        let request = LanguageModelRequest {
 930            messages: vec![
 931                LanguageModelRequestMessage {
 932                    role: Role::System,
 933                    content: vec![MessageContent::Text("System prompt".into())],
 934                    cache: false,
 935                },
 936                LanguageModelRequestMessage {
 937                    role: Role::User,
 938                    content: vec![MessageContent::Text("Hello".into())],
 939                    cache: false,
 940                },
 941            ],
 942            temperature: Some(0.5),
 943            tools: vec![],
 944            tool_choice: None,
 945            thread_id: None,
 946            prompt_id: None,
 947            intent: None,
 948            mode: None,
 949            stop: vec![],
 950            thinking_allowed: true,
 951        };
 952
 953        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);
 954
 955        assert_eq!(mistral_request.model, "mistral-small-latest");
 956        assert_eq!(mistral_request.temperature, Some(0.5));
 957        assert_eq!(mistral_request.messages.len(), 2);
 958        assert!(mistral_request.stream);
 959    }
 960
 961    #[test]
 962    fn test_into_mistral_with_image() {
 963        let request = LanguageModelRequest {
 964            messages: vec![LanguageModelRequestMessage {
 965                role: Role::User,
 966                content: vec![
 967                    MessageContent::Text("What's in this image?".into()),
 968                    MessageContent::Image(LanguageModelImage {
 969                        source: "base64data".into(),
 970                        size: Default::default(),
 971                    }),
 972                ],
 973                cache: false,
 974            }],
 975            tools: vec![],
 976            tool_choice: None,
 977            temperature: None,
 978            thread_id: None,
 979            prompt_id: None,
 980            intent: None,
 981            mode: None,
 982            stop: vec![],
 983            thinking_allowed: true,
 984        };
 985
 986        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
 987
 988        assert_eq!(mistral_request.messages.len(), 1);
 989        assert!(matches!(
 990            &mistral_request.messages[0],
 991            mistral::RequestMessage::User {
 992                content: mistral::MessageContent::Multipart { .. }
 993            }
 994        ));
 995
 996        if let mistral::RequestMessage::User {
 997            content: mistral::MessageContent::Multipart { content },
 998        } = &mistral_request.messages[0]
 999        {
1000            assert_eq!(content.len(), 2);
1001            assert!(matches!(
1002                &content[0],
1003                mistral::MessagePart::Text { text } if text == "What's in this image?"
1004            ));
1005            assert!(matches!(
1006                &content[1],
1007                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
1008            ));
1009        }
1010    }
1011}