language_model: Remove dependencies on individual model provider crates (#25503)

Created by Marshall Bowers

This PR removes the `language_model` crate's dependencies on the individual
model provider crates.

The conversion methods for turning a `LanguageModelRequest` into a
provider-specific request type have been moved out of `language_model` and into
the corresponding provider modules in the `language_models` crate, where they
now live as free functions.
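
Concretely, call sites switch from the old methods on `LanguageModelRequest` to the free functions exported by each provider module. A minimal sketch of the shape of the change, condensed from the `anthropic.rs` and `cloud.rs` hunks below (identifiers are taken from the diff, not new API):

```rust
// Before: conversion was a method on LanguageModelRequest,
// defined in crates/language_model/src/request.rs.
let request = request.into_anthropic(
    model.id().into(),
    model.default_temperature(),
    model.max_output_tokens(),
);

// After: conversion is a free function in the provider module
// (crates/language_models/src/provider/anthropic.rs), taking the
// request as its first argument.
let request = into_anthropic(
    request,
    model.id().into(),
    model.default_temperature(),
    model.max_output_tokens(),
);
```

The same pattern applies to `into_open_ai`, `into_google`, `into_mistral`, and `into_deepseek`.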

The provider crates backing Zed's cloud offering (`anthropic`, `google_ai`, and
`open_ai`) remain as dependencies, for now.

Release Notes:

- N/A

Change summary

Cargo.lock                                       |   4 
crates/language_model/Cargo.toml                 |   4 
crates/language_model/src/model/mod.rs           |   4 
crates/language_model/src/request.rs             | 292 ------------------
crates/language_model/src/role.rs                |  40 --
crates/language_models/src/provider/anthropic.rs | 119 +++++++
crates/language_models/src/provider/cloud.rs     |  23 
crates/language_models/src/provider/deepseek.rs  | 100 +++++
crates/language_models/src/provider/google.rs    |  36 ++
crates/language_models/src/provider/mistral.rs   |  54 +++
crates/language_models/src/provider/open_ai.rs   |  37 ++
11 files changed, 347 insertions(+), 366 deletions(-)

Detailed changes

Cargo.lock

@@ -7015,16 +7015,12 @@ dependencies = [
  "anyhow",
  "base64 0.22.1",
  "collections",
- "deepseek",
  "futures 0.3.31",
  "google_ai",
  "gpui",
  "http_client",
  "image",
- "lmstudio",
  "log",
- "mistral",
- "ollama",
  "open_ai",
  "parking_lot",
  "proto",

crates/language_model/Cargo.toml

@@ -20,16 +20,12 @@ anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 base64.workspace = true
 collections.workspace = true
-deepseek = { workspace = true, features = ["schemars"] }
 futures.workspace = true
 google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
 image.workspace = true
-lmstudio = { workspace = true, features = ["schemars"] }
 log.workspace = true
-mistral = { workspace = true, features = ["schemars"] }
-ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 parking_lot.workspace = true
 proto.workspace = true

crates/language_model/src/model/mod.rs

@@ -1,7 +1,3 @@
 pub mod cloud_model;
 
-pub use anthropic::Model as AnthropicModel;
 pub use cloud_model::*;
-pub use lmstudio::Model as LmStudioModel;
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;

crates/language_model/src/request.rs

@@ -241,298 +241,6 @@ pub struct LanguageModelRequest {
     pub temperature: Option<f32>,
 }
 
-impl LanguageModelRequest {
-    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
-        let stream = !model.starts_with("o1-");
-        open_ai::Request {
-            model,
-            messages: self
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => open_ai::RequestMessage::User {
-                        content: msg.string_contents(),
-                    },
-                    Role::Assistant => open_ai::RequestMessage::Assistant {
-                        content: Some(msg.string_contents()),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => open_ai::RequestMessage::System {
-                        content: msg.string_contents(),
-                    },
-                })
-                .collect(),
-            stream,
-            stop: self.stop,
-            temperature: self.temperature.unwrap_or(1.0),
-            max_tokens: max_output_tokens,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
-    }
-
-    pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
-        let len = self.messages.len();
-        let merged_messages =
-            self.messages
-                .into_iter()
-                .fold(Vec::with_capacity(len), |mut acc, msg| {
-                    let role = msg.role;
-                    let content = msg.string_contents();
-
-                    acc.push(match role {
-                        Role::User => mistral::RequestMessage::User { content },
-                        Role::Assistant => mistral::RequestMessage::Assistant {
-                            content: Some(content),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => mistral::RequestMessage::System { content },
-                    });
-                    acc
-                });
-
-        mistral::Request {
-            model,
-            messages: merged_messages,
-            stream: true,
-            max_tokens: max_output_tokens,
-            temperature: self.temperature,
-            response_format: None,
-            tools: self
-                .tools
-                .into_iter()
-                .map(|tool| mistral::ToolDefinition::Function {
-                    function: mistral::FunctionDefinition {
-                        name: tool.name,
-                        description: Some(tool.description),
-                        parameters: Some(tool.input_schema),
-                    },
-                })
-                .collect(),
-        }
-    }
-
-    pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
-        google_ai::GenerateContentRequest {
-            model,
-            contents: self
-                .messages
-                .into_iter()
-                .map(|msg| google_ai::Content {
-                    parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
-                        text: msg.string_contents(),
-                    })],
-                    role: match msg.role {
-                        Role::User => google_ai::Role::User,
-                        Role::Assistant => google_ai::Role::Model,
-                        Role::System => google_ai::Role::User, // Google AI doesn't have a system role
-                    },
-                })
-                .collect(),
-            generation_config: Some(google_ai::GenerationConfig {
-                candidate_count: Some(1),
-                stop_sequences: Some(self.stop),
-                max_output_tokens: None,
-                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
-                top_p: None,
-                top_k: None,
-            }),
-            safety_settings: None,
-        }
-    }
-
-    pub fn into_anthropic(
-        self,
-        model: String,
-        default_temperature: f32,
-        max_output_tokens: u32,
-    ) -> anthropic::Request {
-        let mut new_messages: Vec<anthropic::Message> = Vec::new();
-        let mut system_message = String::new();
-
-        for message in self.messages {
-            if message.contents_empty() {
-                continue;
-            }
-
-            match message.role {
-                Role::User | Role::Assistant => {
-                    let cache_control = if message.cache {
-                        Some(anthropic::CacheControl {
-                            cache_type: anthropic::CacheControlType::Ephemeral,
-                        })
-                    } else {
-                        None
-                    };
-                    let anthropic_message_content: Vec<anthropic::RequestContent> = message
-                        .content
-                        .into_iter()
-                        .filter_map(|content| match content {
-                            MessageContent::Text(text) => {
-                                if !text.is_empty() {
-                                    Some(anthropic::RequestContent::Text {
-                                        text,
-                                        cache_control,
-                                    })
-                                } else {
-                                    None
-                                }
-                            }
-                            MessageContent::Image(image) => {
-                                Some(anthropic::RequestContent::Image {
-                                    source: anthropic::ImageSource {
-                                        source_type: "base64".to_string(),
-                                        media_type: "image/png".to_string(),
-                                        data: image.source.to_string(),
-                                    },
-                                    cache_control,
-                                })
-                            }
-                            MessageContent::ToolUse(tool_use) => {
-                                Some(anthropic::RequestContent::ToolUse {
-                                    id: tool_use.id.to_string(),
-                                    name: tool_use.name,
-                                    input: tool_use.input,
-                                    cache_control,
-                                })
-                            }
-                            MessageContent::ToolResult(tool_result) => {
-                                Some(anthropic::RequestContent::ToolResult {
-                                    tool_use_id: tool_result.tool_use_id,
-                                    is_error: tool_result.is_error,
-                                    content: tool_result.content,
-                                    cache_control,
-                                })
-                            }
-                        })
-                        .collect();
-                    let anthropic_role = match message.role {
-                        Role::User => anthropic::Role::User,
-                        Role::Assistant => anthropic::Role::Assistant,
-                        Role::System => unreachable!("System role should never occur here"),
-                    };
-                    if let Some(last_message) = new_messages.last_mut() {
-                        if last_message.role == anthropic_role {
-                            last_message.content.extend(anthropic_message_content);
-                            continue;
-                        }
-                    }
-                    new_messages.push(anthropic::Message {
-                        role: anthropic_role,
-                        content: anthropic_message_content,
-                    });
-                }
-                Role::System => {
-                    if !system_message.is_empty() {
-                        system_message.push_str("\n\n");
-                    }
-                    system_message.push_str(&message.string_contents());
-                }
-            }
-        }
-
-        anthropic::Request {
-            model,
-            messages: new_messages,
-            max_tokens: max_output_tokens,
-            system: Some(system_message),
-            tools: self
-                .tools
-                .into_iter()
-                .map(|tool| anthropic::Tool {
-                    name: tool.name,
-                    description: tool.description,
-                    input_schema: tool.input_schema,
-                })
-                .collect(),
-            tool_choice: None,
-            metadata: None,
-            stop_sequences: Vec::new(),
-            temperature: self.temperature.or(Some(default_temperature)),
-            top_k: None,
-            top_p: None,
-        }
-    }
-
-    pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
-        let is_reasoner = model == "deepseek-reasoner";
-
-        let len = self.messages.len();
-        let merged_messages =
-            self.messages
-                .into_iter()
-                .fold(Vec::with_capacity(len), |mut acc, msg| {
-                    let role = msg.role;
-                    let content = msg.string_contents();
-
-                    if is_reasoner {
-                        if let Some(last_msg) = acc.last_mut() {
-                            match (last_msg, role) {
-                                (deepseek::RequestMessage::User { content: last }, Role::User) => {
-                                    last.push(' ');
-                                    last.push_str(&content);
-                                    return acc;
-                                }
-
-                                (
-                                    deepseek::RequestMessage::Assistant {
-                                        content: last_content,
-                                        ..
-                                    },
-                                    Role::Assistant,
-                                ) => {
-                                    *last_content = last_content
-                                        .take()
-                                        .map(|c| {
-                                            let mut s =
-                                                String::with_capacity(c.len() + content.len() + 1);
-                                            s.push_str(&c);
-                                            s.push(' ');
-                                            s.push_str(&content);
-                                            s
-                                        })
-                                        .or(Some(content));
-
-                                    return acc;
-                                }
-                                _ => {}
-                            }
-                        }
-                    }
-
-                    acc.push(match role {
-                        Role::User => deepseek::RequestMessage::User { content },
-                        Role::Assistant => deepseek::RequestMessage::Assistant {
-                            content: Some(content),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => deepseek::RequestMessage::System { content },
-                    });
-                    acc
-                });
-
-        deepseek::Request {
-            model,
-            messages: merged_messages,
-            stream: true,
-            max_tokens: max_output_tokens,
-            temperature: if is_reasoner { None } else { self.temperature },
-            response_format: None,
-            tools: self
-                .tools
-                .into_iter()
-                .map(|tool| deepseek::ToolDefinition::Function {
-                    function: deepseek::FunctionDefinition {
-                        name: tool.name,
-                        description: Some(tool.description),
-                        parameters: Some(tool.input_schema),
-                    },
-                })
-                .collect(),
-        }
-    }
-}
-
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct LanguageModelResponseMessage {
     pub role: Option<Role>,

crates/language_model/src/role.rs

@@ -45,43 +45,3 @@ impl Display for Role {
         }
     }
 }
-
-impl From<Role> for ollama::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => ollama::Role::User,
-            Role::Assistant => ollama::Role::Assistant,
-            Role::System => ollama::Role::System,
-        }
-    }
-}
-
-impl From<Role> for open_ai::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => open_ai::Role::User,
-            Role::Assistant => open_ai::Role::Assistant,
-            Role::System => open_ai::Role::System,
-        }
-    }
-}
-
-impl From<Role> for deepseek::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => deepseek::Role::User,
-            Role::Assistant => deepseek::Role::Assistant,
-            Role::System => deepseek::Role::System,
-        }
-    }
-}
-
-impl From<Role> for lmstudio::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => lmstudio::Role::User,
-            Role::Assistant => lmstudio::Role::Assistant,
-            Role::System => lmstudio::Role::System,
-        }
-    }
-}

crates/language_models/src/provider/anthropic.rs

@@ -13,7 +13,7 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;
@@ -396,7 +396,8 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request = request.into_anthropic(
+        let request = into_anthropic(
+            request,
             self.model.id().into(),
             self.model.default_temperature(),
             self.model.max_output_tokens(),
@@ -427,7 +428,8 @@ impl LanguageModel for AnthropicModel {
         input_schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_anthropic(
+        let mut request = into_anthropic(
+            request,
             self.model.tool_model_id().into(),
             self.model.default_temperature(),
             self.model.max_output_tokens(),
@@ -456,6 +458,117 @@ impl LanguageModel for AnthropicModel {
     }
 }
 
+pub fn into_anthropic(
+    request: LanguageModelRequest,
+    model: String,
+    default_temperature: f32,
+    max_output_tokens: u32,
+) -> anthropic::Request {
+    let mut new_messages: Vec<anthropic::Message> = Vec::new();
+    let mut system_message = String::new();
+
+    for message in request.messages {
+        if message.contents_empty() {
+            continue;
+        }
+
+        match message.role {
+            Role::User | Role::Assistant => {
+                let cache_control = if message.cache {
+                    Some(anthropic::CacheControl {
+                        cache_type: anthropic::CacheControlType::Ephemeral,
+                    })
+                } else {
+                    None
+                };
+                let anthropic_message_content: Vec<anthropic::RequestContent> = message
+                    .content
+                    .into_iter()
+                    .filter_map(|content| match content {
+                        MessageContent::Text(text) => {
+                            if !text.is_empty() {
+                                Some(anthropic::RequestContent::Text {
+                                    text,
+                                    cache_control,
+                                })
+                            } else {
+                                None
+                            }
+                        }
+                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
+                            source: anthropic::ImageSource {
+                                source_type: "base64".to_string(),
+                                media_type: "image/png".to_string(),
+                                data: image.source.to_string(),
+                            },
+                            cache_control,
+                        }),
+                        MessageContent::ToolUse(tool_use) => {
+                            Some(anthropic::RequestContent::ToolUse {
+                                id: tool_use.id.to_string(),
+                                name: tool_use.name,
+                                input: tool_use.input,
+                                cache_control,
+                            })
+                        }
+                        MessageContent::ToolResult(tool_result) => {
+                            Some(anthropic::RequestContent::ToolResult {
+                                tool_use_id: tool_result.tool_use_id,
+                                is_error: tool_result.is_error,
+                                content: tool_result.content,
+                                cache_control,
+                            })
+                        }
+                    })
+                    .collect();
+                let anthropic_role = match message.role {
+                    Role::User => anthropic::Role::User,
+                    Role::Assistant => anthropic::Role::Assistant,
+                    Role::System => unreachable!("System role should never occur here"),
+                };
+                if let Some(last_message) = new_messages.last_mut() {
+                    if last_message.role == anthropic_role {
+                        last_message.content.extend(anthropic_message_content);
+                        continue;
+                    }
+                }
+                new_messages.push(anthropic::Message {
+                    role: anthropic_role,
+                    content: anthropic_message_content,
+                });
+            }
+            Role::System => {
+                if !system_message.is_empty() {
+                    system_message.push_str("\n\n");
+                }
+                system_message.push_str(&message.string_contents());
+            }
+        }
+    }
+
+    anthropic::Request {
+        model,
+        messages: new_messages,
+        max_tokens: max_output_tokens,
+        system: Some(system_message),
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| anthropic::Tool {
+                name: tool.name,
+                description: tool.description,
+                input_schema: tool.input_schema,
+            })
+            .collect(),
+        tool_choice: None,
+        metadata: None,
+        stop_sequences: Vec::new(),
+        temperature: request.temperature.or(Some(default_temperature)),
+        top_k: None,
+        top_p: None,
+    }
+}
+
 pub fn map_to_language_model_completion_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {

crates/language_models/src/provider/cloud.rs

@@ -1,4 +1,3 @@
-use super::open_ai::count_open_ai_tokens;
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Result};
 use client::{
@@ -43,11 +42,13 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{prelude::*, TintColor};
 
-use crate::provider::anthropic::map_to_language_model_completion_events;
+use crate::provider::anthropic::{
+    count_anthropic_tokens, into_anthropic, map_to_language_model_completion_events,
+};
+use crate::provider::google::into_google;
+use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
 use crate::AllLanguageModelSettings;
 
-use super::anthropic::count_anthropic_tokens;
-
 pub const PROVIDER_NAME: &str = "Zed";
 
 const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
@@ -612,7 +613,7 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
             CloudModel::Google(model) => {
                 let client = self.client.clone();
-                let request = request.into_google(model.id().into());
+                let request = into_google(request, model.id().into());
                 let request = google_ai::CountTokensRequest {
                     contents: request.contents,
                 };
@@ -638,7 +639,8 @@ impl LanguageModel for CloudLanguageModel {
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let request = request.into_anthropic(
+                let request = into_anthropic(
+                    request,
                     model.id().into(),
                     model.default_temperature(),
                     model.max_output_tokens(),
@@ -666,7 +668,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
-                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
+                let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
@@ -693,7 +695,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
-                let request = request.into_google(model.id().into());
+                let request = into_google(request, model.id().into());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
@@ -736,7 +738,8 @@ impl LanguageModel for CloudLanguageModel {
 
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let mut request = request.into_anthropic(
+                let mut request = into_anthropic(
+                    request,
                     model.tool_model_id().into(),
                     model.default_temperature(),
                     model.max_output_tokens(),
@@ -776,7 +779,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::OpenAi(model) => {
                 let mut request =
-                    request.into_open_ai(model.id().into(), model.max_output_tokens());
+                    into_open_ai(request, model.id().into(), model.max_output_tokens());
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {

crates/language_models/src/provider/deepseek.rs

@@ -322,7 +322,11 @@ impl LanguageModel for DeepSeekLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+        let request = into_deepseek(
+            request,
+            self.model.id().to_string(),
+            self.max_output_tokens(),
+        );
         let stream = self.stream_completion(request, cx);
 
         async move {
@@ -357,8 +361,11 @@ impl LanguageModel for DeepSeekLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
-        let mut deepseek_request =
-            request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+        let mut deepseek_request = into_deepseek(
+            request,
+            self.model.id().to_string(),
+            self.max_output_tokens(),
+        );
 
         deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
             function: deepseek::FunctionDefinition {
@@ -402,6 +409,93 @@ impl LanguageModel for DeepSeekLanguageModel {
     }
 }
 
+pub fn into_deepseek(
+    request: LanguageModelRequest,
+    model: String,
+    max_output_tokens: Option<u32>,
+) -> deepseek::Request {
+    let is_reasoner = model == "deepseek-reasoner";
+
+    let len = request.messages.len();
+    let merged_messages =
+        request
+            .messages
+            .into_iter()
+            .fold(Vec::with_capacity(len), |mut acc, msg| {
+                let role = msg.role;
+                let content = msg.string_contents();
+
+                if is_reasoner {
+                    if let Some(last_msg) = acc.last_mut() {
+                        match (last_msg, role) {
+                            (deepseek::RequestMessage::User { content: last }, Role::User) => {
+                                last.push(' ');
+                                last.push_str(&content);
+                                return acc;
+                            }
+
+                            (
+                                deepseek::RequestMessage::Assistant {
+                                    content: last_content,
+                                    ..
+                                },
+                                Role::Assistant,
+                            ) => {
+                                *last_content = last_content
+                                    .take()
+                                    .map(|c| {
+                                        let mut s =
+                                            String::with_capacity(c.len() + content.len() + 1);
+                                        s.push_str(&c);
+                                        s.push(' ');
+                                        s.push_str(&content);
+                                        s
+                                    })
+                                    .or(Some(content));
+
+                                return acc;
+                            }
+                            _ => {}
+                        }
+                    }
+                }
+
+                acc.push(match role {
+                    Role::User => deepseek::RequestMessage::User { content },
+                    Role::Assistant => deepseek::RequestMessage::Assistant {
+                        content: Some(content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => deepseek::RequestMessage::System { content },
+                });
+                acc
+            });
+
+    deepseek::Request {
+        model,
+        messages: merged_messages,
+        stream: true,
+        max_tokens: max_output_tokens,
+        temperature: if is_reasoner {
+            None
+        } else {
+            request.temperature
+        },
+        response_format: None,
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| deepseek::ToolDefinition::Function {
+                function: deepseek::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: Entity<State>,

crates/language_models/src/provider/google.rs

@@ -272,7 +272,7 @@ impl LanguageModel for GoogleLanguageModel {
         request: LanguageModelRequest,
         cx: &App,
     ) -> BoxFuture<'static, Result<usize>> {
-        let request = request.into_google(self.model.id().to_string());
+        let request = into_google(request, self.model.id().to_string());
         let http_client = self.http_client.clone();
         let api_key = self.state.read(cx).api_key.clone();
 
@@ -303,7 +303,7 @@ impl LanguageModel for GoogleLanguageModel {
         'static,
         Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
     > {
-        let request = request.into_google(self.model.id().to_string());
+        let request = into_google(request, self.model.id().to_string());
 
         let http_client = self.http_client.clone();
         let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
@@ -341,6 +341,38 @@ impl LanguageModel for GoogleLanguageModel {
     }
 }
 
+pub fn into_google(
+    request: LanguageModelRequest,
+    model: String,
+) -> google_ai::GenerateContentRequest {
+    google_ai::GenerateContentRequest {
+        model,
+        contents: request
+            .messages
+            .into_iter()
+            .map(|msg| google_ai::Content {
+                parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
+                    text: msg.string_contents(),
+                })],
+                role: match msg.role {
+                    Role::User => google_ai::Role::User,
+                    Role::Assistant => google_ai::Role::Model,
+                    Role::System => google_ai::Role::User, // Google AI doesn't have a system role
+                },
+            })
+            .collect(),
+        generation_config: Some(google_ai::GenerationConfig {
+            candidate_count: Some(1),
+            stop_sequences: Some(request.stop),
+            max_output_tokens: None,
+            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+            top_p: None,
+            top_k: None,
+        }),
+        safety_settings: None,
+    }
+}
+
 pub fn count_google_tokens(
     request: LanguageModelRequest,
     cx: &App,

crates/language_models/src/provider/mistral.rs

@@ -334,7 +334,11 @@ impl LanguageModel for MistralLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens());
+        let request = into_mistral(
+            request,
+            self.model.id().to_string(),
+            self.max_output_tokens(),
+        );
         let stream = self.stream_completion(request, cx);
 
         async move {
@@ -369,7 +373,7 @@ impl LanguageModel for MistralLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens());
+        let mut request = into_mistral(request, self.model.id().into(), self.max_output_tokens());
         request.tools = vec![mistral::ToolDefinition::Function {
             function: mistral::FunctionDefinition {
                 name: tool_name.clone(),
@@ -411,6 +415,52 @@ impl LanguageModel for MistralLanguageModel {
     }
 }
 
+pub fn into_mistral(
+    request: LanguageModelRequest,
+    model: String,
+    max_output_tokens: Option<u32>,
+) -> mistral::Request {
+    let len = request.messages.len();
+    let merged_messages =
+        request
+            .messages
+            .into_iter()
+            .fold(Vec::with_capacity(len), |mut acc, msg| {
+                let role = msg.role;
+                let content = msg.string_contents();
+
+                acc.push(match role {
+                    Role::User => mistral::RequestMessage::User { content },
+                    Role::Assistant => mistral::RequestMessage::Assistant {
+                        content: Some(content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => mistral::RequestMessage::System { content },
+                });
+                acc
+            });
+
+    mistral::Request {
+        model,
+        messages: merged_messages,
+        stream: true,
+        max_tokens: max_output_tokens,
+        temperature: request.temperature,
+        response_format: None,
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| mistral::ToolDefinition::Function {
+                function: mistral::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,

crates/language_models/src/provider/open_ai.rs

@@ -318,7 +318,7 @@ impl LanguageModel for OpenAiLanguageModel {
         'static,
         Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
     > {
-        let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
+        let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
         async move {
             Ok(open_ai::extract_text_from_events(completions.await?)
@@ -336,7 +336,7 @@ impl LanguageModel for OpenAiLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncApp,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
+        let mut request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
         request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
             function: FunctionDefinition {
                 name: tool_name.clone(),
@@ -366,6 +366,39 @@ impl LanguageModel for OpenAiLanguageModel {
     }
 }
 
+pub fn into_open_ai(
+    request: LanguageModelRequest,
+    model: String,
+    max_output_tokens: Option<u32>,
+) -> open_ai::Request {
+    let stream = !model.starts_with("o1-");
+    open_ai::Request {
+        model,
+        messages: request
+            .messages
+            .into_iter()
+            .map(|msg| match msg.role {
+                Role::User => open_ai::RequestMessage::User {
+                    content: msg.string_contents(),
+                },
+                Role::Assistant => open_ai::RequestMessage::Assistant {
+                    content: Some(msg.string_contents()),
+                    tool_calls: Vec::new(),
+                },
+                Role::System => open_ai::RequestMessage::System {
+                    content: msg.string_contents(),
+                },
+            })
+            .collect(),
+        stream,
+        stop: request.stop,
+        temperature: request.temperature.unwrap_or(1.0),
+        max_tokens: max_output_tokens,
+        tools: Vec::new(),
+        tool_choice: None,
+    }
+}
+
 pub fn count_open_ai_tokens(
     request: LanguageModelRequest,
     model: open_ai::Model,