mistral.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use collections::BTreeMap;
   3use credentials_provider::CredentialsProvider;
   4use editor::{Editor, EditorElement, EditorStyle};
   5use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
   6use gpui::{
   7    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
   8};
   9use http_client::HttpClient;
  10use language_model::{
  11    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  12    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  13    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  14    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  15    RateLimiter, Role, StopReason, TokenUsage,
  16};
  17use mistral::StreamResponse;
  18use schemars::JsonSchema;
  19use serde::{Deserialize, Serialize};
  20use settings::{Settings, SettingsStore};
  21use std::collections::HashMap;
  22use std::pin::Pin;
  23use std::str::FromStr;
  24use std::sync::Arc;
  25use strum::IntoEnumIterator;
  26use theme::ThemeSettings;
  27use ui::{Icon, IconName, List, Tooltip, prelude::*};
  28use util::ResultExt;
  29
  30use crate::{AllLanguageModelSettings, ui::InstructionListItem};
  31
/// Identifier reported for this provider and for every model it creates.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
/// Human-readable provider name reported for this provider and its models.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
  34
/// Mistral-specific entries of the language-model settings.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
    /// API URL passed to `mistral::stream_completion`; also used as the key under which the
    /// API key is written to / read from the credentials store.
    pub api_url: String,
    /// User-declared models, exposed as `mistral::Model::Custom` and overriding built-in
    /// models that share the same name.
    pub available_models: Vec<AvailableModel>,
}
  40
// A user-configured Mistral model from settings. Each field is copied verbatim into
// `mistral::Model::Custom` by `provided_models`.
// NOTE(review): plain `//` comments are used on purpose — `///` doc comments would change the
// JSON schema schemars generates for the settings file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Model name; also the key used to override a built-in model with the same id.
    pub name: String,
    // Optional name shown in the UI instead of `name` (presumably; see `Model::display_name`).
    pub display_name: Option<String>,
    // Presumably the model's context window in tokens — TODO confirm against `mistral::Model`.
    pub max_tokens: u64,
    // Optional output-token limit; forwarded as `Custom { max_output_tokens }`.
    pub max_output_tokens: Option<u64>,
    // Optional completion-token limit; forwarded as `Custom { max_completion_tokens }`.
    pub max_completion_tokens: Option<u64>,
    // Overrides whether the model is treated as supporting tool calls.
    pub supports_tools: Option<bool>,
    // Overrides whether image parts are forwarded to the model.
    pub supports_images: Option<bool>,
    // Overrides whether thinking parts are forwarded to the model.
    pub supports_thinking: Option<bool>,
}
  52
/// Language-model provider that exposes Mistral models.
pub struct MistralLanguageModelProvider {
    /// HTTP client handed to every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with the created models and the configuration view.
    state: gpui::Entity<State>,
}
  57
/// Authentication state shared by the provider, its models, and the configuration view.
pub struct State {
    /// The active API key, if one has been loaded or entered.
    api_key: Option<String>,
    /// Whether `api_key` was taken from the `MISTRAL_API_KEY` environment variable.
    api_key_from_env: bool,
    /// Keeps the settings observer alive; it re-notifies observers of this entity on every
    /// settings change.
    _subscription: Subscription,
}
  63
/// Environment variable checked for an API key before consulting the credentials store.
const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY";
  65
  66impl State {
  67    fn is_authenticated(&self) -> bool {
  68        self.api_key.is_some()
  69    }
  70
  71    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
  72        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  73        let api_url = AllLanguageModelSettings::get_global(cx)
  74            .mistral
  75            .api_url
  76            .clone();
  77        cx.spawn(async move |this, cx| {
  78            credentials_provider
  79                .delete_credentials(&api_url, &cx)
  80                .await
  81                .log_err();
  82            this.update(cx, |this, cx| {
  83                this.api_key = None;
  84                this.api_key_from_env = false;
  85                cx.notify();
  86            })
  87        })
  88    }
  89
  90    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
  91        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  92        let api_url = AllLanguageModelSettings::get_global(cx)
  93            .mistral
  94            .api_url
  95            .clone();
  96        cx.spawn(async move |this, cx| {
  97            credentials_provider
  98                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
  99                .await?;
 100            this.update(cx, |this, cx| {
 101                this.api_key = Some(api_key);
 102                cx.notify();
 103            })
 104        })
 105    }
 106
 107    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 108        if self.is_authenticated() {
 109            return Task::ready(Ok(()));
 110        }
 111
 112        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 113        let api_url = AllLanguageModelSettings::get_global(cx)
 114            .mistral
 115            .api_url
 116            .clone();
 117        cx.spawn(async move |this, cx| {
 118            let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
 119                (api_key, true)
 120            } else {
 121                let (_, api_key) = credentials_provider
 122                    .read_credentials(&api_url, &cx)
 123                    .await?
 124                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 125                (
 126                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 127                    false,
 128                )
 129            };
 130            this.update(cx, |this, cx| {
 131                this.api_key = Some(api_key);
 132                this.api_key_from_env = from_env;
 133                cx.notify();
 134            })?;
 135
 136            Ok(())
 137        })
 138    }
 139}
 140
 141impl MistralLanguageModelProvider {
 142    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 143        let state = cx.new(|cx| State {
 144            api_key: None,
 145            api_key_from_env: false,
 146            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
 147                cx.notify();
 148            }),
 149        });
 150
 151        Self { http_client, state }
 152    }
 153
 154    fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
 155        Arc::new(MistralLanguageModel {
 156            id: LanguageModelId::from(model.id().to_string()),
 157            model,
 158            state: self.state.clone(),
 159            http_client: self.http_client.clone(),
 160            request_limiter: RateLimiter::new(4),
 161        })
 162    }
 163}
 164
 165impl LanguageModelProviderState for MistralLanguageModelProvider {
 166    type ObservableEntity = State;
 167
 168    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
 169        Some(self.state.clone())
 170    }
 171}
 172
 173impl LanguageModelProvider for MistralLanguageModelProvider {
 174    fn id(&self) -> LanguageModelProviderId {
 175        PROVIDER_ID
 176    }
 177
 178    fn name(&self) -> LanguageModelProviderName {
 179        PROVIDER_NAME
 180    }
 181
 182    fn icon(&self) -> IconName {
 183        IconName::AiMistral
 184    }
 185
 186    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 187        Some(self.create_language_model(mistral::Model::default()))
 188    }
 189
 190    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 191        Some(self.create_language_model(mistral::Model::default_fast()))
 192    }
 193
 194    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 195        let mut models = BTreeMap::default();
 196
 197        // Add base models from mistral::Model::iter()
 198        for model in mistral::Model::iter() {
 199            if !matches!(model, mistral::Model::Custom { .. }) {
 200                models.insert(model.id().to_string(), model);
 201            }
 202        }
 203
 204        // Override with available models from settings
 205        for model in &AllLanguageModelSettings::get_global(cx)
 206            .mistral
 207            .available_models
 208        {
 209            models.insert(
 210                model.name.clone(),
 211                mistral::Model::Custom {
 212                    name: model.name.clone(),
 213                    display_name: model.display_name.clone(),
 214                    max_tokens: model.max_tokens,
 215                    max_output_tokens: model.max_output_tokens,
 216                    max_completion_tokens: model.max_completion_tokens,
 217                    supports_tools: model.supports_tools,
 218                    supports_images: model.supports_images,
 219                    supports_thinking: model.supports_thinking,
 220                },
 221            );
 222        }
 223
 224        models
 225            .into_values()
 226            .map(|model| {
 227                Arc::new(MistralLanguageModel {
 228                    id: LanguageModelId::from(model.id().to_string()),
 229                    model,
 230                    state: self.state.clone(),
 231                    http_client: self.http_client.clone(),
 232                    request_limiter: RateLimiter::new(4),
 233                }) as Arc<dyn LanguageModel>
 234            })
 235            .collect()
 236    }
 237
 238    fn is_authenticated(&self, cx: &App) -> bool {
 239        self.state.read(cx).is_authenticated()
 240    }
 241
 242    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 243        self.state.update(cx, |state, cx| state.authenticate(cx))
 244    }
 245
 246    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
 247        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 248            .into()
 249    }
 250
 251    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 252        self.state.update(cx, |state, cx| state.reset_api_key(cx))
 253    }
 254}
 255
/// A single Mistral model exposed through the [`LanguageModel`] trait.
pub struct MistralLanguageModel {
    /// Identifier derived from `model.id()`.
    id: LanguageModelId,
    /// The underlying Mistral model description (built-in or custom).
    model: mistral::Model,
    /// Shared authentication state; read for the API key at request time.
    state: gpui::Entity<State>,
    /// Client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Rate limiter that `stream_completion` routes every request through.
    request_limiter: RateLimiter,
}
 263
 264impl MistralLanguageModel {
 265    fn stream_completion(
 266        &self,
 267        request: mistral::Request,
 268        cx: &AsyncApp,
 269    ) -> BoxFuture<
 270        'static,
 271        Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
 272    > {
 273        let http_client = self.http_client.clone();
 274        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
 275            let settings = &AllLanguageModelSettings::get_global(cx).mistral;
 276            (state.api_key.clone(), settings.api_url.clone())
 277        }) else {
 278            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
 279        };
 280
 281        let future = self.request_limiter.stream(async move {
 282            let api_key = api_key.context("Missing Mistral API Key")?;
 283            let request =
 284                mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
 285            let response = request.await?;
 286            Ok(response)
 287        });
 288
 289        async move { Ok(future.await?.boxed()) }.boxed()
 290    }
 291}
 292
 293impl LanguageModel for MistralLanguageModel {
 294    fn id(&self) -> LanguageModelId {
 295        self.id.clone()
 296    }
 297
 298    fn name(&self) -> LanguageModelName {
 299        LanguageModelName::from(self.model.display_name().to_string())
 300    }
 301
 302    fn provider_id(&self) -> LanguageModelProviderId {
 303        PROVIDER_ID
 304    }
 305
 306    fn provider_name(&self) -> LanguageModelProviderName {
 307        PROVIDER_NAME
 308    }
 309
 310    fn supports_tools(&self) -> bool {
 311        self.model.supports_tools()
 312    }
 313
 314    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
 315        self.model.supports_tools()
 316    }
 317
 318    fn supports_images(&self) -> bool {
 319        self.model.supports_images()
 320    }
 321
 322    fn telemetry_id(&self) -> String {
 323        format!("mistral/{}", self.model.id())
 324    }
 325
 326    fn max_token_count(&self) -> u64 {
 327        self.model.max_token_count()
 328    }
 329
 330    fn max_output_tokens(&self) -> Option<u64> {
 331        self.model.max_output_tokens()
 332    }
 333
 334    fn count_tokens(
 335        &self,
 336        request: LanguageModelRequest,
 337        cx: &App,
 338    ) -> BoxFuture<'static, Result<u64>> {
 339        cx.background_spawn(async move {
 340            let messages = request
 341                .messages
 342                .into_iter()
 343                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 344                    role: match message.role {
 345                        Role::User => "user".into(),
 346                        Role::Assistant => "assistant".into(),
 347                        Role::System => "system".into(),
 348                    },
 349                    content: Some(message.string_contents()),
 350                    name: None,
 351                    function_call: None,
 352                })
 353                .collect::<Vec<_>>();
 354
 355            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 356        })
 357        .boxed()
 358    }
 359
 360    fn stream_completion(
 361        &self,
 362        request: LanguageModelRequest,
 363        cx: &AsyncApp,
 364    ) -> BoxFuture<
 365        'static,
 366        Result<
 367            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 368            LanguageModelCompletionError,
 369        >,
 370    > {
 371        let request = into_mistral(request, self.model.clone(), self.max_output_tokens());
 372        let stream = self.stream_completion(request, cx);
 373
 374        async move {
 375            let stream = stream.await?;
 376            let mapper = MistralEventMapper::new();
 377            Ok(mapper.map_stream(stream).boxed())
 378        }
 379        .boxed()
 380    }
 381}
 382
 383pub fn into_mistral(
 384    request: LanguageModelRequest,
 385    model: mistral::Model,
 386    max_output_tokens: Option<u64>,
 387) -> mistral::Request {
 388    let stream = true;
 389
 390    let mut messages = Vec::new();
 391    for message in &request.messages {
 392        match message.role {
 393            Role::User => {
 394                let mut message_content = mistral::MessageContent::empty();
 395                for content in &message.content {
 396                    match content {
 397                        MessageContent::Text(text) => {
 398                            message_content
 399                                .push_part(mistral::MessagePart::Text { text: text.clone() });
 400                        }
 401                        MessageContent::Image(image_content) => {
 402                            if model.supports_images() {
 403                                message_content.push_part(mistral::MessagePart::ImageUrl {
 404                                    image_url: image_content.to_base64_url(),
 405                                });
 406                            }
 407                        }
 408                        MessageContent::Thinking { text, .. } => {
 409                            if model.supports_thinking() {
 410                                message_content.push_part(mistral::MessagePart::Thinking {
 411                                    thinking: vec![mistral::ThinkingPart::Text {
 412                                        text: text.clone(),
 413                                    }],
 414                                });
 415                            }
 416                        }
 417                        MessageContent::RedactedThinking(_) => {}
 418                        MessageContent::ToolUse(_) => {
 419                            // Tool use is not supported in User messages for Mistral
 420                        }
 421                        MessageContent::ToolResult(tool_result) => {
 422                            let tool_content = match &tool_result.content {
 423                                LanguageModelToolResultContent::Text(text) => text.to_string(),
 424                                LanguageModelToolResultContent::Image(_) => {
 425                                    "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
 426                                }
 427                            };
 428                            messages.push(mistral::RequestMessage::Tool {
 429                                content: tool_content,
 430                                tool_call_id: tool_result.tool_use_id.to_string(),
 431                            });
 432                        }
 433                    }
 434                }
 435                if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
 436                {
 437                    messages.push(mistral::RequestMessage::User {
 438                        content: message_content,
 439                    });
 440                }
 441            }
 442            Role::Assistant => {
 443                for content in &message.content {
 444                    match content {
 445                        MessageContent::Text(text) => {
 446                            messages.push(mistral::RequestMessage::Assistant {
 447                                content: Some(mistral::MessageContent::Plain {
 448                                    content: text.clone(),
 449                                }),
 450                                tool_calls: Vec::new(),
 451                            });
 452                        }
 453                        MessageContent::Thinking { text, .. } => {
 454                            if model.supports_thinking() {
 455                                messages.push(mistral::RequestMessage::Assistant {
 456                                    content: Some(mistral::MessageContent::Multipart {
 457                                        content: vec![mistral::MessagePart::Thinking {
 458                                            thinking: vec![mistral::ThinkingPart::Text {
 459                                                text: text.clone(),
 460                                            }],
 461                                        }],
 462                                    }),
 463                                    tool_calls: Vec::new(),
 464                                });
 465                            }
 466                        }
 467                        MessageContent::RedactedThinking(_) => {}
 468                        MessageContent::Image(_) => {}
 469                        MessageContent::ToolUse(tool_use) => {
 470                            let tool_call = mistral::ToolCall {
 471                                id: tool_use.id.to_string(),
 472                                content: mistral::ToolCallContent::Function {
 473                                    function: mistral::FunctionContent {
 474                                        name: tool_use.name.to_string(),
 475                                        arguments: serde_json::to_string(&tool_use.input)
 476                                            .unwrap_or_default(),
 477                                    },
 478                                },
 479                            };
 480
 481                            if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
 482                                messages.last_mut()
 483                            {
 484                                tool_calls.push(tool_call);
 485                            } else {
 486                                messages.push(mistral::RequestMessage::Assistant {
 487                                    content: None,
 488                                    tool_calls: vec![tool_call],
 489                                });
 490                            }
 491                        }
 492                        MessageContent::ToolResult(_) => {
 493                            // Tool results are not supported in Assistant messages
 494                        }
 495                    }
 496                }
 497            }
 498            Role::System => {
 499                for content in &message.content {
 500                    match content {
 501                        MessageContent::Text(text) => {
 502                            messages.push(mistral::RequestMessage::System {
 503                                content: mistral::MessageContent::Plain {
 504                                    content: text.clone(),
 505                                },
 506                            });
 507                        }
 508                        MessageContent::Thinking { text, .. } => {
 509                            if model.supports_thinking() {
 510                                messages.push(mistral::RequestMessage::System {
 511                                    content: mistral::MessageContent::Multipart {
 512                                        content: vec![mistral::MessagePart::Thinking {
 513                                            thinking: vec![mistral::ThinkingPart::Text {
 514                                                text: text.clone(),
 515                                            }],
 516                                        }],
 517                                    },
 518                                });
 519                            }
 520                        }
 521                        MessageContent::RedactedThinking(_) => {}
 522                        MessageContent::Image(_)
 523                        | MessageContent::ToolUse(_)
 524                        | MessageContent::ToolResult(_) => {
 525                            // Images and tools are not supported in System messages
 526                        }
 527                    }
 528                }
 529            }
 530        }
 531    }
 532
 533    mistral::Request {
 534        model: model.id().to_string(),
 535        messages,
 536        stream,
 537        max_tokens: max_output_tokens,
 538        temperature: request.temperature,
 539        response_format: None,
 540        tool_choice: match request.tool_choice {
 541            Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
 542                Some(mistral::ToolChoice::Auto)
 543            }
 544            Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
 545                Some(mistral::ToolChoice::Any)
 546            }
 547            Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
 548            _ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
 549            _ => None,
 550        },
 551        parallel_tool_calls: if !request.tools.is_empty() {
 552            Some(false)
 553        } else {
 554            None
 555        },
 556        tools: request
 557            .tools
 558            .into_iter()
 559            .map(|tool| mistral::ToolDefinition::Function {
 560                function: mistral::FunctionDefinition {
 561                    name: tool.name,
 562                    description: Some(tool.description),
 563                    parameters: Some(tool.input_schema),
 564                },
 565            })
 566            .collect(),
 567    }
 568}
 569
 570pub struct MistralEventMapper {
 571    tool_calls_by_index: HashMap<usize, RawToolCall>,
 572}
 573
 574impl MistralEventMapper {
 575    pub fn new() -> Self {
 576        Self {
 577            tool_calls_by_index: HashMap::default(),
 578        }
 579    }
 580
 581    pub fn map_stream(
 582        mut self,
 583        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
 584    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 585    {
 586        events.flat_map(move |event| {
 587            futures::stream::iter(match event {
 588                Ok(event) => self.map_event(event),
 589                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
 590            })
 591        })
 592    }
 593
 594    pub fn map_event(
 595        &mut self,
 596        event: mistral::StreamResponse,
 597    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 598        let Some(choice) = event.choices.first() else {
 599            return vec![Err(LanguageModelCompletionError::from(anyhow!(
 600                "Response contained no choices"
 601            )))];
 602        };
 603
 604        let mut events = Vec::new();
 605        if let Some(content) = choice.delta.content.as_ref() {
 606            match content {
 607                mistral::MessageContentDelta::Text(text) => {
 608                    events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 609                }
 610                mistral::MessageContentDelta::Parts(parts) => {
 611                    for part in parts {
 612                        match part {
 613                            mistral::MessagePart::Text { text } => {
 614                                events.push(Ok(LanguageModelCompletionEvent::Text(text.clone())));
 615                            }
 616                            mistral::MessagePart::Thinking { thinking } => {
 617                                for tp in thinking.iter().cloned() {
 618                                    match tp {
 619                                        mistral::ThinkingPart::Text { text } => {
 620                                            events.push(Ok(
 621                                                LanguageModelCompletionEvent::Thinking {
 622                                                    text,
 623                                                    signature: None,
 624                                                },
 625                                            ));
 626                                        }
 627                                    }
 628                                }
 629                            }
 630                            mistral::MessagePart::ImageUrl { .. } => {
 631                                // We currently don't emit a separate event for images in responses.
 632                            }
 633                        }
 634                    }
 635                }
 636            }
 637        }
 638
 639        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
 640            for tool_call in tool_calls {
 641                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
 642
 643                if let Some(tool_id) = tool_call.id.clone() {
 644                    entry.id = tool_id;
 645                }
 646
 647                if let Some(function) = tool_call.function.as_ref() {
 648                    if let Some(name) = function.name.clone() {
 649                        entry.name = name;
 650                    }
 651
 652                    if let Some(arguments) = function.arguments.clone() {
 653                        entry.arguments.push_str(&arguments);
 654                    }
 655                }
 656            }
 657        }
 658
 659        if let Some(usage) = event.usage {
 660            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 661                input_tokens: usage.prompt_tokens,
 662                output_tokens: usage.completion_tokens,
 663                cache_creation_input_tokens: 0,
 664                cache_read_input_tokens: 0,
 665            })));
 666        }
 667
 668        if let Some(finish_reason) = choice.finish_reason.as_deref() {
 669            match finish_reason {
 670                "stop" => {
 671                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 672                }
 673                "tool_calls" => {
 674                    events.extend(self.process_tool_calls());
 675                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 676                }
 677                unexpected => {
 678                    log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
 679                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 680                }
 681            }
 682        }
 683
 684        events
 685    }
 686
 687    fn process_tool_calls(
 688        &mut self,
 689    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 690        let mut results = Vec::new();
 691
 692        for (_, tool_call) in self.tool_calls_by_index.drain() {
 693            if tool_call.id.is_empty() || tool_call.name.is_empty() {
 694                results.push(Err(LanguageModelCompletionError::from(anyhow!(
 695                    "Received incomplete tool call: missing id or name"
 696                ))));
 697                continue;
 698            }
 699
 700            match serde_json::Value::from_str(&tool_call.arguments) {
 701                Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
 702                    LanguageModelToolUse {
 703                        id: tool_call.id.into(),
 704                        name: tool_call.name.into(),
 705                        is_input_complete: true,
 706                        input,
 707                        raw_input: tool_call.arguments,
 708                    },
 709                ))),
 710                Err(error) => {
 711                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 712                        id: tool_call.id.into(),
 713                        tool_name: tool_call.name.into(),
 714                        raw_input: tool_call.arguments.into(),
 715                        json_parse_error: error.to_string(),
 716                    }))
 717                }
 718            }
 719        }
 720
 721        results
 722    }
 723}
 724
/// A tool call assembled from fragments across streamed chunks.
///
/// Empty strings mean the corresponding fragment has not been received yet.
#[derive(Default)]
struct RawToolCall {
    /// Tool-call id, set from the chunk that carries it.
    id: String,
    /// Name of the function to call.
    name: String,
    /// Concatenated JSON argument fragments; parsed once the call is complete.
    arguments: String,
}
 731
/// Settings UI for entering or resetting the Mistral API key.
struct ConfigurationView {
    /// Single-line editor the user types the API key into.
    api_key_editor: Entity<Editor>,
    /// Shared provider authentication state.
    state: gpui::Entity<State>,
    /// Initial authentication attempt started on construction; set to `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
}
 737
 738impl ConfigurationView {
 739    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 740        let api_key_editor = cx.new(|cx| {
 741            let mut editor = Editor::single_line(window, cx);
 742            editor.set_placeholder_text("0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2", cx);
 743            editor
 744        });
 745
 746        cx.observe(&state, |_, _, cx| {
 747            cx.notify();
 748        })
 749        .detach();
 750
 751        let load_credentials_task = Some(cx.spawn_in(window, {
 752            let state = state.clone();
 753            async move |this, cx| {
 754                if let Some(task) = state
 755                    .update(cx, |state, cx| state.authenticate(cx))
 756                    .log_err()
 757                {
 758                    // We don't log an error, because "not signed in" is also an error.
 759                    let _ = task.await;
 760                }
 761
 762                this.update(cx, |this, cx| {
 763                    this.load_credentials_task = None;
 764                    cx.notify();
 765                })
 766                .log_err();
 767            }
 768        }));
 769
 770        Self {
 771            api_key_editor,
 772            state,
 773            load_credentials_task,
 774        }
 775    }
 776
 777    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 778        let api_key = self.api_key_editor.read(cx).text(cx);
 779        if api_key.is_empty() {
 780            return;
 781        }
 782
 783        let state = self.state.clone();
 784        cx.spawn_in(window, async move |_, cx| {
 785            state
 786                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
 787                .await
 788        })
 789        .detach_and_log_err(cx);
 790
 791        cx.notify();
 792    }
 793
 794    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 795        self.api_key_editor
 796            .update(cx, |editor, cx| editor.set_text("", window, cx));
 797
 798        let state = self.state.clone();
 799        cx.spawn_in(window, async move |_, cx| {
 800            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
 801        })
 802        .detach_and_log_err(cx);
 803
 804        cx.notify();
 805    }
 806
 807    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
 808        let settings = ThemeSettings::get_global(cx);
 809        let text_style = TextStyle {
 810            color: cx.theme().colors().text,
 811            font_family: settings.ui_font.family.clone(),
 812            font_features: settings.ui_font.features.clone(),
 813            font_fallbacks: settings.ui_font.fallbacks.clone(),
 814            font_size: rems(0.875).into(),
 815            font_weight: settings.ui_font.weight,
 816            font_style: FontStyle::Normal,
 817            line_height: relative(1.3),
 818            white_space: WhiteSpace::Normal,
 819            ..Default::default()
 820        };
 821        EditorElement::new(
 822            &self.api_key_editor,
 823            EditorStyle {
 824                background: cx.theme().colors().editor_background,
 825                local_player: cx.theme().players().local(),
 826                text: text_style,
 827                ..Default::default()
 828            },
 829        )
 830    }
 831
 832    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
 833        !self.state.read(cx).is_authenticated()
 834    }
 835}
 836
 837impl Render for ConfigurationView {
 838    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 839        let env_var_set = self.state.read(cx).api_key_from_env;
 840
 841        if self.load_credentials_task.is_some() {
 842            div().child(Label::new("Loading credentials...")).into_any()
 843        } else if self.should_render_editor(cx) {
 844            v_flex()
 845                .size_full()
 846                .on_action(cx.listener(Self::save_api_key))
 847                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
 848                .child(
 849                    List::new()
 850                        .child(InstructionListItem::new(
 851                            "Create one by visiting",
 852                            Some("Mistral's console"),
 853                            Some("https://console.mistral.ai/api-keys"),
 854                        ))
 855                        .child(InstructionListItem::text_only(
 856                            "Ensure your Mistral account has credits",
 857                        ))
 858                        .child(InstructionListItem::text_only(
 859                            "Paste your API key below and hit enter to start using the assistant",
 860                        )),
 861                )
 862                .child(
 863                    h_flex()
 864                        .w_full()
 865                        .my_2()
 866                        .px_2()
 867                        .py_1()
 868                        .bg(cx.theme().colors().editor_background)
 869                        .border_1()
 870                        .border_color(cx.theme().colors().border)
 871                        .rounded_sm()
 872                        .child(self.render_api_key_editor(cx)),
 873                )
 874                .child(
 875                    Label::new(
 876                        format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."),
 877                    )
 878                    .size(LabelSize::Small).color(Color::Muted),
 879                )
 880                .into_any()
 881        } else {
 882            h_flex()
 883                .mt_1()
 884                .p_1()
 885                .justify_between()
 886                .rounded_md()
 887                .border_1()
 888                .border_color(cx.theme().colors().border)
 889                .bg(cx.theme().colors().background)
 890                .child(
 891                    h_flex()
 892                        .gap_1()
 893                        .child(Icon::new(IconName::Check).color(Color::Success))
 894                        .child(Label::new(if env_var_set {
 895                            format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.")
 896                        } else {
 897                            "API key configured.".to_string()
 898                        })),
 899                )
 900                .child(
 901                    Button::new("reset-key", "Reset Key")
 902                        .label_size(LabelSize::Small)
 903                        .icon(Some(IconName::Trash))
 904                        .icon_size(IconSize::Small)
 905                        .icon_position(IconPosition::Start)
 906                        .disabled(env_var_set)
 907                        .when(env_var_set, |this| {
 908                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable.")))
 909                        })
 910                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
 911                )
 912                .into_any()
 913        }
 914    }
 915}
 916
 917#[cfg(test)]
 918mod tests {
 919    use super::*;
 920    use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
 921
 922    #[test]
 923    fn test_into_mistral_basic_conversion() {
 924        let request = LanguageModelRequest {
 925            messages: vec![
 926                LanguageModelRequestMessage {
 927                    role: Role::System,
 928                    content: vec![MessageContent::Text("System prompt".into())],
 929                    cache: false,
 930                },
 931                LanguageModelRequestMessage {
 932                    role: Role::User,
 933                    content: vec![MessageContent::Text("Hello".into())],
 934                    cache: false,
 935                },
 936            ],
 937            temperature: Some(0.5),
 938            tools: vec![],
 939            tool_choice: None,
 940            thread_id: None,
 941            prompt_id: None,
 942            intent: None,
 943            mode: None,
 944            stop: vec![],
 945            thinking_allowed: true,
 946        };
 947
 948        let mistral_request = into_mistral(request, mistral::Model::MistralSmallLatest, None);
 949
 950        assert_eq!(mistral_request.model, "mistral-small-latest");
 951        assert_eq!(mistral_request.temperature, Some(0.5));
 952        assert_eq!(mistral_request.messages.len(), 2);
 953        assert!(mistral_request.stream);
 954    }
 955
 956    #[test]
 957    fn test_into_mistral_with_image() {
 958        let request = LanguageModelRequest {
 959            messages: vec![LanguageModelRequestMessage {
 960                role: Role::User,
 961                content: vec![
 962                    MessageContent::Text("What's in this image?".into()),
 963                    MessageContent::Image(LanguageModelImage {
 964                        source: "base64data".into(),
 965                        size: Default::default(),
 966                    }),
 967                ],
 968                cache: false,
 969            }],
 970            tools: vec![],
 971            tool_choice: None,
 972            temperature: None,
 973            thread_id: None,
 974            prompt_id: None,
 975            intent: None,
 976            mode: None,
 977            stop: vec![],
 978            thinking_allowed: true,
 979        };
 980
 981        let mistral_request = into_mistral(request, mistral::Model::Pixtral12BLatest, None);
 982
 983        assert_eq!(mistral_request.messages.len(), 1);
 984        assert!(matches!(
 985            &mistral_request.messages[0],
 986            mistral::RequestMessage::User {
 987                content: mistral::MessageContent::Multipart { .. }
 988            }
 989        ));
 990
 991        if let mistral::RequestMessage::User {
 992            content: mistral::MessageContent::Multipart { content },
 993        } = &mistral_request.messages[0]
 994        {
 995            assert_eq!(content.len(), 2);
 996            assert!(matches!(
 997                &content[0],
 998                mistral::MessagePart::Text { text } if text == "What's in this image?"
 999            ));
1000            assert!(matches!(
1001                &content[1],
1002                mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
1003            ));
1004        }
1005    }
1006}