Detailed changes
@@ -7015,16 +7015,12 @@ dependencies = [
"anyhow",
"base64 0.22.1",
"collections",
- "deepseek",
"futures 0.3.31",
"google_ai",
"gpui",
"http_client",
"image",
- "lmstudio",
"log",
- "mistral",
- "ollama",
"open_ai",
"parking_lot",
"proto",
@@ -20,16 +20,12 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
base64.workspace = true
collections.workspace = true
-deepseek = { workspace = true, features = ["schemars"] }
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
image.workspace = true
-lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
-mistral = { workspace = true, features = ["schemars"] }
-ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
parking_lot.workspace = true
proto.workspace = true
@@ -1,7 +1,3 @@
pub mod cloud_model;
-pub use anthropic::Model as AnthropicModel;
pub use cloud_model::*;
-pub use lmstudio::Model as LmStudioModel;
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;
@@ -241,298 +241,6 @@ pub struct LanguageModelRequest {
pub temperature: Option<f32>,
}
-impl LanguageModelRequest {
- pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
- let stream = !model.starts_with("o1-");
- open_ai::Request {
- model,
- messages: self
- .messages
- .into_iter()
- .map(|msg| match msg.role {
- Role::User => open_ai::RequestMessage::User {
- content: msg.string_contents(),
- },
- Role::Assistant => open_ai::RequestMessage::Assistant {
- content: Some(msg.string_contents()),
- tool_calls: Vec::new(),
- },
- Role::System => open_ai::RequestMessage::System {
- content: msg.string_contents(),
- },
- })
- .collect(),
- stream,
- stop: self.stop,
- temperature: self.temperature.unwrap_or(1.0),
- max_tokens: max_output_tokens,
- tools: Vec::new(),
- tool_choice: None,
- }
- }
-
- pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
- let len = self.messages.len();
- let merged_messages =
- self.messages
- .into_iter()
- .fold(Vec::with_capacity(len), |mut acc, msg| {
- let role = msg.role;
- let content = msg.string_contents();
-
- acc.push(match role {
- Role::User => mistral::RequestMessage::User { content },
- Role::Assistant => mistral::RequestMessage::Assistant {
- content: Some(content),
- tool_calls: Vec::new(),
- },
- Role::System => mistral::RequestMessage::System { content },
- });
- acc
- });
-
- mistral::Request {
- model,
- messages: merged_messages,
- stream: true,
- max_tokens: max_output_tokens,
- temperature: self.temperature,
- response_format: None,
- tools: self
- .tools
- .into_iter()
- .map(|tool| mistral::ToolDefinition::Function {
- function: mistral::FunctionDefinition {
- name: tool.name,
- description: Some(tool.description),
- parameters: Some(tool.input_schema),
- },
- })
- .collect(),
- }
- }
-
- pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
- google_ai::GenerateContentRequest {
- model,
- contents: self
- .messages
- .into_iter()
- .map(|msg| google_ai::Content {
- parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
- text: msg.string_contents(),
- })],
- role: match msg.role {
- Role::User => google_ai::Role::User,
- Role::Assistant => google_ai::Role::Model,
- Role::System => google_ai::Role::User, // Google AI doesn't have a system role
- },
- })
- .collect(),
- generation_config: Some(google_ai::GenerationConfig {
- candidate_count: Some(1),
- stop_sequences: Some(self.stop),
- max_output_tokens: None,
- temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
- top_p: None,
- top_k: None,
- }),
- safety_settings: None,
- }
- }
-
- pub fn into_anthropic(
- self,
- model: String,
- default_temperature: f32,
- max_output_tokens: u32,
- ) -> anthropic::Request {
- let mut new_messages: Vec<anthropic::Message> = Vec::new();
- let mut system_message = String::new();
-
- for message in self.messages {
- if message.contents_empty() {
- continue;
- }
-
- match message.role {
- Role::User | Role::Assistant => {
- let cache_control = if message.cache {
- Some(anthropic::CacheControl {
- cache_type: anthropic::CacheControlType::Ephemeral,
- })
- } else {
- None
- };
- let anthropic_message_content: Vec<anthropic::RequestContent> = message
- .content
- .into_iter()
- .filter_map(|content| match content {
- MessageContent::Text(text) => {
- if !text.is_empty() {
- Some(anthropic::RequestContent::Text {
- text,
- cache_control,
- })
- } else {
- None
- }
- }
- MessageContent::Image(image) => {
- Some(anthropic::RequestContent::Image {
- source: anthropic::ImageSource {
- source_type: "base64".to_string(),
- media_type: "image/png".to_string(),
- data: image.source.to_string(),
- },
- cache_control,
- })
- }
- MessageContent::ToolUse(tool_use) => {
- Some(anthropic::RequestContent::ToolUse {
- id: tool_use.id.to_string(),
- name: tool_use.name,
- input: tool_use.input,
- cache_control,
- })
- }
- MessageContent::ToolResult(tool_result) => {
- Some(anthropic::RequestContent::ToolResult {
- tool_use_id: tool_result.tool_use_id,
- is_error: tool_result.is_error,
- content: tool_result.content,
- cache_control,
- })
- }
- })
- .collect();
- let anthropic_role = match message.role {
- Role::User => anthropic::Role::User,
- Role::Assistant => anthropic::Role::Assistant,
- Role::System => unreachable!("System role should never occur here"),
- };
- if let Some(last_message) = new_messages.last_mut() {
- if last_message.role == anthropic_role {
- last_message.content.extend(anthropic_message_content);
- continue;
- }
- }
- new_messages.push(anthropic::Message {
- role: anthropic_role,
- content: anthropic_message_content,
- });
- }
- Role::System => {
- if !system_message.is_empty() {
- system_message.push_str("\n\n");
- }
- system_message.push_str(&message.string_contents());
- }
- }
- }
-
- anthropic::Request {
- model,
- messages: new_messages,
- max_tokens: max_output_tokens,
- system: Some(system_message),
- tools: self
- .tools
- .into_iter()
- .map(|tool| anthropic::Tool {
- name: tool.name,
- description: tool.description,
- input_schema: tool.input_schema,
- })
- .collect(),
- tool_choice: None,
- metadata: None,
- stop_sequences: Vec::new(),
- temperature: self.temperature.or(Some(default_temperature)),
- top_k: None,
- top_p: None,
- }
- }
-
- pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
- let is_reasoner = model == "deepseek-reasoner";
-
- let len = self.messages.len();
- let merged_messages =
- self.messages
- .into_iter()
- .fold(Vec::with_capacity(len), |mut acc, msg| {
- let role = msg.role;
- let content = msg.string_contents();
-
- if is_reasoner {
- if let Some(last_msg) = acc.last_mut() {
- match (last_msg, role) {
- (deepseek::RequestMessage::User { content: last }, Role::User) => {
- last.push(' ');
- last.push_str(&content);
- return acc;
- }
-
- (
- deepseek::RequestMessage::Assistant {
- content: last_content,
- ..
- },
- Role::Assistant,
- ) => {
- *last_content = last_content
- .take()
- .map(|c| {
- let mut s =
- String::with_capacity(c.len() + content.len() + 1);
- s.push_str(&c);
- s.push(' ');
- s.push_str(&content);
- s
- })
- .or(Some(content));
-
- return acc;
- }
- _ => {}
- }
- }
- }
-
- acc.push(match role {
- Role::User => deepseek::RequestMessage::User { content },
- Role::Assistant => deepseek::RequestMessage::Assistant {
- content: Some(content),
- tool_calls: Vec::new(),
- },
- Role::System => deepseek::RequestMessage::System { content },
- });
- acc
- });
-
- deepseek::Request {
- model,
- messages: merged_messages,
- stream: true,
- max_tokens: max_output_tokens,
- temperature: if is_reasoner { None } else { self.temperature },
- response_format: None,
- tools: self
- .tools
- .into_iter()
- .map(|tool| deepseek::ToolDefinition::Function {
- function: deepseek::FunctionDefinition {
- name: tool.name,
- description: Some(tool.description),
- parameters: Some(tool.input_schema),
- },
- })
- .collect(),
- }
- }
-}
-
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
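
Note: the `into_open_ai`, `into_mistral`, `into_google`, `into_anthropic`, and `into_deepseek` conversions removed above reappear as free functions in the corresponding provider modules (see the hunks below), so call sites now pass the request as the first argument instead of calling a method on it. A minimal sketch of the call-site shape, assuming `request`, `model_id`, `default_temperature`, and `max_output_tokens` are already in scope (not a standalone program):

    // Before: conversion was a method on LanguageModelRequest.
    let anthropic_request =
        request.into_anthropic(model_id, default_temperature, max_output_tokens);

    // After: conversion is a free function in the provider module.
    let anthropic_request = crate::provider::anthropic::into_anthropic(
        request,
        model_id,
        default_temperature,
        max_output_tokens,
    );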
@@ -45,43 +45,3 @@ impl Display for Role {
}
}
}
-
-impl From<Role> for ollama::Role {
- fn from(val: Role) -> Self {
- match val {
- Role::User => ollama::Role::User,
- Role::Assistant => ollama::Role::Assistant,
- Role::System => ollama::Role::System,
- }
- }
-}
-
-impl From<Role> for open_ai::Role {
- fn from(val: Role) -> Self {
- match val {
- Role::User => open_ai::Role::User,
- Role::Assistant => open_ai::Role::Assistant,
- Role::System => open_ai::Role::System,
- }
- }
-}
-
-impl From<Role> for deepseek::Role {
- fn from(val: Role) -> Self {
- match val {
- Role::User => deepseek::Role::User,
- Role::Assistant => deepseek::Role::Assistant,
- Role::System => deepseek::Role::System,
- }
- }
-}
-
-impl From<Role> for lmstudio::Role {
- fn from(val: Role) -> Self {
- match val {
- Role::User => lmstudio::Role::User,
- Role::Assistant => lmstudio::Role::Assistant,
- Role::System => lmstudio::Role::System,
- }
- }
-}
@@ -13,7 +13,7 @@ use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+ LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@@ -396,7 +396,8 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
- let request = request.into_anthropic(
+ let request = into_anthropic(
+ request,
self.model.id().into(),
self.model.default_temperature(),
self.model.max_output_tokens(),
@@ -427,7 +428,8 @@ impl LanguageModel for AnthropicModel {
input_schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let mut request = request.into_anthropic(
+ let mut request = into_anthropic(
+ request,
self.model.tool_model_id().into(),
self.model.default_temperature(),
self.model.max_output_tokens(),
@@ -456,6 +458,117 @@ impl LanguageModel for AnthropicModel {
}
}
+pub fn into_anthropic(
+ request: LanguageModelRequest,
+ model: String,
+ default_temperature: f32,
+ max_output_tokens: u32,
+) -> anthropic::Request {
+ let mut new_messages: Vec<anthropic::Message> = Vec::new();
+ let mut system_message = String::new();
+
+ for message in request.messages {
+ if message.contents_empty() {
+ continue;
+ }
+
+ match message.role {
+ Role::User | Role::Assistant => {
+ let cache_control = if message.cache {
+ Some(anthropic::CacheControl {
+ cache_type: anthropic::CacheControlType::Ephemeral,
+ })
+ } else {
+ None
+ };
+ let anthropic_message_content: Vec<anthropic::RequestContent> = message
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ MessageContent::Text(text) => {
+ if !text.is_empty() {
+ Some(anthropic::RequestContent::Text {
+ text,
+ cache_control,
+ })
+ } else {
+ None
+ }
+ }
+ MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
+ source: anthropic::ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ cache_control,
+ }),
+ MessageContent::ToolUse(tool_use) => {
+ Some(anthropic::RequestContent::ToolUse {
+ id: tool_use.id.to_string(),
+ name: tool_use.name,
+ input: tool_use.input,
+ cache_control,
+ })
+ }
+ MessageContent::ToolResult(tool_result) => {
+ Some(anthropic::RequestContent::ToolResult {
+ tool_use_id: tool_result.tool_use_id,
+ is_error: tool_result.is_error,
+ content: tool_result.content,
+ cache_control,
+ })
+ }
+ })
+ .collect();
+ let anthropic_role = match message.role {
+ Role::User => anthropic::Role::User,
+ Role::Assistant => anthropic::Role::Assistant,
+ Role::System => unreachable!("System role should never occur here"),
+ };
+ if let Some(last_message) = new_messages.last_mut() {
+ if last_message.role == anthropic_role {
+ last_message.content.extend(anthropic_message_content);
+ continue;
+ }
+ }
+ new_messages.push(anthropic::Message {
+ role: anthropic_role,
+ content: anthropic_message_content,
+ });
+ }
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.string_contents());
+ }
+ }
+ }
+
+ anthropic::Request {
+ model,
+ messages: new_messages,
+ max_tokens: max_output_tokens,
+ system: Some(system_message),
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| anthropic::Tool {
+ name: tool.name,
+ description: tool.description,
+ input_schema: tool.input_schema,
+ })
+ .collect(),
+ tool_choice: None,
+ metadata: None,
+ stop_sequences: Vec::new(),
+ temperature: request.temperature.or(Some(default_temperature)),
+ top_k: None,
+ top_p: None,
+ }
+}
+
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
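
As before the move, `into_anthropic` merges consecutive user/assistant messages and concatenates all system messages into the single `system` field. A simplified, self-contained model of that merging logic (toy types here, not the crate's):

    #[derive(Clone, Copy, PartialEq, Debug)]
    enum Role { User, Assistant, System }

    #[derive(Debug)]
    struct Msg { role: Role, text: String }

    // Model of the merge in `into_anthropic`: system text is concatenated,
    // and a message whose role matches the previous one is appended to it.
    fn merge(messages: Vec<Msg>) -> (String, Vec<Msg>) {
        let mut system = String::new();
        let mut merged: Vec<Msg> = Vec::new();
        for msg in messages {
            match msg.role {
                Role::System => {
                    if !system.is_empty() {
                        system.push_str("\n\n");
                    }
                    system.push_str(&msg.text);
                }
                _ => {
                    if let Some(last) = merged.last_mut() {
                        if last.role == msg.role {
                            last.text.push_str(&msg.text);
                            continue;
                        }
                    }
                    merged.push(msg);
                }
            }
        }
        (system, merged)
    }

    fn main() {
        let (system, merged) = merge(vec![
            Msg { role: Role::System, text: "be terse".into() },
            Msg { role: Role::User, text: "hi ".into() },
            Msg { role: Role::User, text: "there".into() },
            Msg { role: Role::Assistant, text: "hello".into() },
        ]);
        assert_eq!(system, "be terse");
        assert_eq!(merged.len(), 2); // the two user messages collapse into one
    }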
@@ -1,4 +1,3 @@
-use super::open_ai::count_open_ai_tokens;
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{
@@ -43,11 +42,13 @@ use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{prelude::*, TintColor};
-use crate::provider::anthropic::map_to_language_model_completion_events;
+use crate::provider::anthropic::{
+ count_anthropic_tokens, into_anthropic, map_to_language_model_completion_events,
+};
+use crate::provider::google::into_google;
+use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
use crate::AllLanguageModelSettings;
-use super::anthropic::count_anthropic_tokens;
-
pub const PROVIDER_NAME: &str = "Zed";
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
@@ -612,7 +613,7 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
CloudModel::Google(model) => {
let client = self.client.clone();
- let request = request.into_google(model.id().into());
+ let request = into_google(request, model.id().into());
let request = google_ai::CountTokensRequest {
contents: request.contents,
};
@@ -638,7 +639,8 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
- let request = request.into_anthropic(
+ let request = into_anthropic(
+ request,
model.id().into(),
model.default_temperature(),
model.max_output_tokens(),
@@ -666,7 +668,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
- let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
+ let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
@@ -693,7 +695,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::Google(model) => {
let client = self.client.clone();
- let request = request.into_google(model.id().into());
+ let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
@@ -736,7 +738,8 @@ impl LanguageModel for CloudLanguageModel {
match &self.model {
CloudModel::Anthropic(model) => {
- let mut request = request.into_anthropic(
+ let mut request = into_anthropic(
+ request,
model.tool_model_id().into(),
model.default_temperature(),
model.max_output_tokens(),
@@ -776,7 +779,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::OpenAi(model) => {
let mut request =
- request.into_open_ai(model.id().into(), model.max_output_tokens());
+ into_open_ai(request, model.id().into(), model.max_output_tokens());
request.tool_choice = Some(open_ai::ToolChoice::Other(
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
@@ -322,7 +322,11 @@ impl LanguageModel for DeepSeekLanguageModel {
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
- let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+ let request = into_deepseek(
+ request,
+ self.model.id().to_string(),
+ self.max_output_tokens(),
+ );
let stream = self.stream_completion(request, cx);
async move {
@@ -357,8 +361,11 @@ impl LanguageModel for DeepSeekLanguageModel {
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
- let mut deepseek_request =
- request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
+ let mut deepseek_request = into_deepseek(
+ request,
+ self.model.id().to_string(),
+ self.max_output_tokens(),
+ );
deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
function: deepseek::FunctionDefinition {
@@ -402,6 +409,93 @@ impl LanguageModel for DeepSeekLanguageModel {
}
}
+pub fn into_deepseek(
+ request: LanguageModelRequest,
+ model: String,
+ max_output_tokens: Option<u32>,
+) -> deepseek::Request {
+ let is_reasoner = model == "deepseek-reasoner";
+
+ let len = request.messages.len();
+ let merged_messages =
+ request
+ .messages
+ .into_iter()
+ .fold(Vec::with_capacity(len), |mut acc, msg| {
+ let role = msg.role;
+ let content = msg.string_contents();
+
+ if is_reasoner {
+ if let Some(last_msg) = acc.last_mut() {
+ match (last_msg, role) {
+ (deepseek::RequestMessage::User { content: last }, Role::User) => {
+ last.push(' ');
+ last.push_str(&content);
+ return acc;
+ }
+
+ (
+ deepseek::RequestMessage::Assistant {
+ content: last_content,
+ ..
+ },
+ Role::Assistant,
+ ) => {
+ *last_content = last_content
+ .take()
+ .map(|c| {
+ let mut s =
+ String::with_capacity(c.len() + content.len() + 1);
+ s.push_str(&c);
+ s.push(' ');
+ s.push_str(&content);
+ s
+ })
+ .or(Some(content));
+
+ return acc;
+ }
+ _ => {}
+ }
+ }
+ }
+
+ acc.push(match role {
+ Role::User => deepseek::RequestMessage::User { content },
+ Role::Assistant => deepseek::RequestMessage::Assistant {
+ content: Some(content),
+ tool_calls: Vec::new(),
+ },
+ Role::System => deepseek::RequestMessage::System { content },
+ });
+ acc
+ });
+
+ deepseek::Request {
+ model,
+ messages: merged_messages,
+ stream: true,
+ max_tokens: max_output_tokens,
+ temperature: if is_reasoner {
+ None
+ } else {
+ request.temperature
+ },
+ response_format: None,
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| deepseek::ToolDefinition::Function {
+ function: deepseek::FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ })
+ .collect(),
+ }
+}
+
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: Entity<State>,
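
For `deepseek-reasoner`, adjacent messages that share a role are joined with a single space before the request is built; other DeepSeek models keep the messages as-is. A small standalone model of that fold (plain role strings instead of the crate's message enum):

    // Joins adjacent same-role messages with a space, as the reasoner path does.
    fn merge_adjacent(messages: Vec<(&'static str, String)>) -> Vec<(&'static str, String)> {
        messages.into_iter().fold(Vec::new(), |mut acc, (role, content)| {
            if let Some((last_role, last_content)) = acc.last_mut() {
                if *last_role == role {
                    last_content.push(' ');
                    last_content.push_str(&content);
                    return acc;
                }
            }
            acc.push((role, content));
            acc
        })
    }

    fn main() {
        let merged = merge_adjacent(vec![
            ("user", "first".into()),
            ("user", "second".into()),
            ("assistant", "reply".into()),
        ]);
        assert_eq!(merged, vec![
            ("user", "first second".to_string()),
            ("assistant", "reply".to_string()),
        ]);
    }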
@@ -272,7 +272,7 @@ impl LanguageModel for GoogleLanguageModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
- let request = request.into_google(self.model.id().to_string());
+ let request = into_google(request, self.model.id().to_string());
let http_client = self.http_client.clone();
let api_key = self.state.read(cx).api_key.clone();
@@ -303,7 +303,7 @@ impl LanguageModel for GoogleLanguageModel {
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
- let request = request.into_google(self.model.id().to_string());
+ let request = into_google(request, self.model.id().to_string());
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
@@ -341,6 +341,38 @@ impl LanguageModel for GoogleLanguageModel {
}
}
+pub fn into_google(
+ request: LanguageModelRequest,
+ model: String,
+) -> google_ai::GenerateContentRequest {
+ google_ai::GenerateContentRequest {
+ model,
+ contents: request
+ .messages
+ .into_iter()
+ .map(|msg| google_ai::Content {
+ parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
+ text: msg.string_contents(),
+ })],
+ role: match msg.role {
+ Role::User => google_ai::Role::User,
+ Role::Assistant => google_ai::Role::Model,
+ Role::System => google_ai::Role::User, // Google AI doesn't have a system role
+ },
+ })
+ .collect(),
+ generation_config: Some(google_ai::GenerationConfig {
+ candidate_count: Some(1),
+ stop_sequences: Some(request.stop),
+ max_output_tokens: None,
+ temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+ top_p: None,
+ top_k: None,
+ }),
+ safety_settings: None,
+ }
+}
+
pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
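
Because this Google AI request format has no system role, `into_google` sends system messages with the user role and maps assistant turns to `Model`. A tiny standalone illustration of that mapping:

    #[derive(Debug, PartialEq)]
    enum GoogleRole { User, Model }

    enum Role { User, Assistant, System }

    // Mirrors the match in `into_google`: System falls back to User.
    fn to_google_role(role: Role) -> GoogleRole {
        match role {
            Role::User | Role::System => GoogleRole::User,
            Role::Assistant => GoogleRole::Model,
        }
    }

    fn main() {
        assert_eq!(to_google_role(Role::User), GoogleRole::User);
        assert_eq!(to_google_role(Role::System), GoogleRole::User);
        assert_eq!(to_google_role(Role::Assistant), GoogleRole::Model);
    }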
@@ -334,7 +334,11 @@ impl LanguageModel for MistralLanguageModel {
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
- let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens());
+ let request = into_mistral(
+ request,
+ self.model.id().to_string(),
+ self.max_output_tokens(),
+ );
let stream = self.stream_completion(request, cx);
async move {
@@ -369,7 +373,7 @@ impl LanguageModel for MistralLanguageModel {
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
- let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens());
+ let mut request = into_mistral(request, self.model.id().into(), self.max_output_tokens());
request.tools = vec![mistral::ToolDefinition::Function {
function: mistral::FunctionDefinition {
name: tool_name.clone(),
@@ -411,6 +415,52 @@ impl LanguageModel for MistralLanguageModel {
}
}
+pub fn into_mistral(
+ request: LanguageModelRequest,
+ model: String,
+ max_output_tokens: Option<u32>,
+) -> mistral::Request {
+ let len = request.messages.len();
+ let merged_messages =
+ request
+ .messages
+ .into_iter()
+ .fold(Vec::with_capacity(len), |mut acc, msg| {
+ let role = msg.role;
+ let content = msg.string_contents();
+
+ acc.push(match role {
+ Role::User => mistral::RequestMessage::User { content },
+ Role::Assistant => mistral::RequestMessage::Assistant {
+ content: Some(content),
+ tool_calls: Vec::new(),
+ },
+ Role::System => mistral::RequestMessage::System { content },
+ });
+ acc
+ });
+
+ mistral::Request {
+ model,
+ messages: merged_messages,
+ stream: true,
+ max_tokens: max_output_tokens,
+ temperature: request.temperature,
+ response_format: None,
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| mistral::ToolDefinition::Function {
+ function: mistral::FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ })
+ .collect(),
+ }
+}
+
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,
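
The Mistral conversion is a straight per-message map (the fold here does not merge adjacent messages) plus a translation of each request tool into a function definition, with the input schema passed through as `parameters`. A minimal standalone sketch of that tool translation, with placeholder types standing in for the request's tool entries and `mistral::FunctionDefinition`:

    // Placeholder shapes; the real types live in Zed's crates and use a JSON schema value.
    struct Tool { name: String, description: String, input_schema: String }
    struct FunctionDef { name: String, description: Option<String>, parameters: Option<String> }

    fn to_function_def(tool: Tool) -> FunctionDef {
        FunctionDef {
            name: tool.name,
            description: Some(tool.description),
            parameters: Some(tool.input_schema),
        }
    }

    fn main() {
        // Hypothetical tool, for illustration only.
        let def = to_function_def(Tool {
            name: "search".into(),
            description: "Searches the workspace".into(),
            input_schema: r#"{"type":"object"}"#.into(),
        });
        assert_eq!(def.name, "search");
    }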
@@ -318,7 +318,7 @@ impl LanguageModel for OpenAiLanguageModel {
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
- let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
+ let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move {
Ok(open_ai::extract_text_from_events(completions.await?)
@@ -336,7 +336,7 @@ impl LanguageModel for OpenAiLanguageModel {
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
- let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
+ let mut request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
function: FunctionDefinition {
name: tool_name.clone(),
@@ -366,6 +366,39 @@ impl LanguageModel for OpenAiLanguageModel {
}
}
+pub fn into_open_ai(
+ request: LanguageModelRequest,
+ model: String,
+ max_output_tokens: Option<u32>,
+) -> open_ai::Request {
+ let stream = !model.starts_with("o1-");
+ open_ai::Request {
+ model,
+ messages: request
+ .messages
+ .into_iter()
+ .map(|msg| match msg.role {
+ Role::User => open_ai::RequestMessage::User {
+ content: msg.string_contents(),
+ },
+ Role::Assistant => open_ai::RequestMessage::Assistant {
+ content: Some(msg.string_contents()),
+ tool_calls: Vec::new(),
+ },
+ Role::System => open_ai::RequestMessage::System {
+ content: msg.string_contents(),
+ },
+ })
+ .collect(),
+ stream,
+ stop: request.stop,
+ temperature: request.temperature.unwrap_or(1.0),
+ max_tokens: max_output_tokens,
+ tools: Vec::new(),
+ tool_choice: None,
+ }
+}
+
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
model: open_ai::Model,
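
`into_open_ai` disables streaming for models whose id starts with "o1-" and falls back to a temperature of 1.0 when none is set. A tiny standalone check of the streaming rule:

    // Mirrors `let stream = !model.starts_with("o1-")` in into_open_ai above.
    fn wants_stream(model_id: &str) -> bool {
        !model_id.starts_with("o1-")
    }

    fn main() {
        assert!(wants_stream("gpt-4o"));
        assert!(!wants_stream("o1-mini"));
    }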