Detailed changes
@@ -153,8 +153,8 @@ impl AssistantSettingsContent {
models
.into_iter()
.filter_map(|model| match model {
- open_ai::Model::Custom { name, max_tokens } => {
- Some(language_model::provider::open_ai::AvailableModel { name, max_tokens })
+ open_ai::Model::Custom { name, max_tokens, max_output_tokens } => {
+ Some(language_model::provider::open_ai::AvailableModel { name, max_tokens, max_output_tokens })
}
_ => None,
})
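Taken together, these hunks thread an optional per-model output-token cap (`max_output_tokens`) from user settings through to the OpenAI request body. A minimal sketch of the intended flow, assuming the crate's types are in scope (`language_model_request` is a hypothetical `LanguageModelRequest` value):

```rust
// Sketch only, not an actual call site: a custom model configured with
// an output cap, threaded into the wire request.
let model = open_ai::Model::Custom {
    name: "gpt-4o".into(),
    max_tokens: 128_000,            // context window (prompt + completion)
    max_output_tokens: Some(4_096), // new: cap on generated tokens
};
let request = language_model_request
    .into_open_ai(model.id().into(), model.max_output_tokens());
// request.max_tokens is now Some(4096) rather than the old hard-coded None.
```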
@@ -254,6 +254,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
+ max_output_tokens: model.max_output_tokens,
}),
AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
name: model.name.clone(),
@@ -513,7 +514,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
- let request = request.into_open_ai(model.id().into());
+ let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
@@ -557,7 +558,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::Zed(model) => {
let client = self.client.clone();
- let mut request = request.into_open_ai(model.id().into());
+ let mut request = request.into_open_ai(model.id().into(), None);
request.max_tokens = Some(4000);
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
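Because the hard-coded `Some(4000)` overwrite follows immediately, the Zed branch could equivalently pass its fixed cap straight through `into_open_ai`; a behavior-preserving alternative, sketched under the same types:

```rust
// Equivalent to the two lines above: hand the fixed Zed cap directly
// to the conversion instead of patching the request afterward.
let request = request.into_open_ai(model.id().into(), Some(4000));
```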
@@ -629,7 +630,8 @@ impl LanguageModel for CloudLanguageModel {
.boxed()
}
CloudModel::OpenAi(model) => {
- let mut request = request.into_open_ai(model.id().into());
+ let mut request =
+ request.into_open_ai(model.id().into(), model.max_output_tokens());
request.tool_choice = Some(open_ai::ToolChoice::Other(
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
@@ -676,7 +678,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::Zed(model) => {
// All Zed models are OpenAI-based at the time of writing.
- let mut request = request.into_open_ai(model.id().into());
+ let mut request = request.into_open_ai(model.id().into(), None);
request.tool_choice = Some(open_ai::ToolChoice::Other(
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
@@ -40,6 +40,7 @@ pub struct OpenAiSettings {
pub struct AvailableModel {
pub name: String,
pub max_tokens: usize,
+ pub max_output_tokens: Option<u32>,
}
pub struct OpenAiLanguageModelProvider {
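With the new optional field, a settings entry can now carry the cap, and older configs without it still deserialize, since serde maps a missing `Option` field to `None`. A self-contained sketch of that behavior (the struct mirrors `AvailableModel` above; the JSON is a hypothetical config snippet):

```rust
use serde::Deserialize;

// Mirror of the settings struct above, reduced to what the sketch needs.
#[derive(Deserialize)]
struct AvailableModel {
    name: String,
    max_tokens: usize,
    max_output_tokens: Option<u32>,
}

fn main() {
    // New-style entry with an explicit output cap.
    let with_cap: AvailableModel = serde_json::from_str(
        r#"{ "name": "gpt-4o", "max_tokens": 128000, "max_output_tokens": 4096 }"#,
    ).unwrap();
    assert_eq!(with_cap.max_output_tokens, Some(4096));

    // A pre-existing entry without the field still parses.
    let legacy: AvailableModel = serde_json::from_str(
        r#"{ "name": "gpt-4o", "max_tokens": 128000 }"#,
    ).unwrap();
    assert_eq!(legacy.max_output_tokens, None);
}
```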
@@ -170,6 +171,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
open_ai::Model::Custom {
name: model.name.clone(),
max_tokens: model.max_tokens,
+ max_output_tokens: model.max_output_tokens,
},
);
}
@@ -275,6 +277,10 @@ impl LanguageModel for OpenAiLanguageModel {
self.model.max_token_count()
}
+ fn max_output_tokens(&self) -> Option<u32> {
+ self.model.max_output_tokens()
+ }
+
fn count_tokens(
&self,
request: LanguageModelRequest,
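The `max_output_tokens` override above implies a matching method on the `LanguageModel` trait itself. That hunk isn't shown here; presumably it is a defaulted accessor, so providers without a known cap need no changes — something along these lines:

```rust
// Assumed trait addition (not part of the diff shown above); a default
// of None means only capped providers have to override it.
fn max_output_tokens(&self) -> Option<u32> {
    None
}
```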
@@ -288,7 +294,7 @@ impl LanguageModel for OpenAiLanguageModel {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
- let request = request.into_open_ai(self.model.id().into());
+ let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
}
@@ -301,7 +307,7 @@ impl LanguageModel for OpenAiLanguageModel {
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
- let mut request = request.into_open_ai(self.model.id().into());
+ let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
function: FunctionDefinition {
name: tool_name.clone(),
@@ -229,7 +229,7 @@ pub struct LanguageModelRequest {
}
impl LanguageModelRequest {
- pub fn into_open_ai(self, model: String) -> open_ai::Request {
+ pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
open_ai::Request {
model,
messages: self
@@ -251,7 +251,7 @@ impl LanguageModelRequest {
stream: true,
stop: self.stop,
temperature: self.temperature,
- max_tokens: None,
+ max_tokens: max_output_tokens,
tools: Vec::new(),
tool_choice: None,
}
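Making the cap an explicit parameter of `into_open_ai`, rather than leaving the field hard-coded to `None`, keeps the decision at each call site: the OpenAI paths forward the model's own cap, while the Zed paths pass `None` and apply their fixed limit afterward.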
@@ -172,9 +172,15 @@ impl OpenAiSettingsContent {
models
.into_iter()
.filter_map(|model| match model {
- open_ai::Model::Custom { name, max_tokens } => {
- Some(provider::open_ai::AvailableModel { name, max_tokens })
- }
+ open_ai::Model::Custom {
+ name,
+ max_tokens,
+ max_output_tokens,
+ } => Some(provider::open_ai::AvailableModel {
+ name,
+ max_tokens,
+ max_output_tokens,
+ }),
_ => None,
})
.collect()
@@ -66,7 +66,11 @@ pub enum Model {
#[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
FourOmniMini,
#[serde(rename = "custom")]
- Custom { name: String, max_tokens: usize },
+ Custom {
+ name: String,
+ max_tokens: usize,
+ max_output_tokens: Option<u32>,
+ },
}
impl Model {
@@ -113,6 +117,19 @@ impl Model {
Self::Custom { max_tokens, .. } => *max_tokens,
}
}
+
+ pub fn max_output_tokens(&self) -> Option<u32> {
+ match self {
+ Self::ThreePointFiveTurbo => Some(4096),
+ Self::Four => Some(8192),
+ Self::FourTurbo => Some(4096),
+ Self::FourOmni => Some(4096),
+ Self::FourOmniMini => Some(16384),
+ Self::Custom {
+ max_output_tokens, ..
+ } => *max_output_tokens,
+ }
+ }
}
#[derive(Debug, Serialize, Deserialize)]
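The fixed variants get caps matching OpenAI's published completion limits at the time of writing, while `Custom` defers to whatever was configured. A test-style sketch of the mapping, assuming `Model` is in scope:

```rust
// Exercising the defaults added above.
assert_eq!(Model::ThreePointFiveTurbo.max_output_tokens(), Some(4096));
assert_eq!(Model::FourOmniMini.max_output_tokens(), Some(16384));

// Custom models report whatever the user configured, if anything.
let custom = Model::Custom {
    name: "my-model".into(), // hypothetical
    max_tokens: 32_768,
    max_output_tokens: None, // unset: no cap is sent upstream
};
assert_eq!(custom.max_output_tokens(), None);
```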
@@ -121,7 +138,7 @@ pub struct Request {
pub messages: Vec<RequestMessage>,
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
- pub max_tokens: Option<usize>,
+ pub max_tokens: Option<u32>,
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
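Narrowing `max_tokens` to `Option<u32>` matches the new accessor's return type, and the existing `skip_serializing_if` attribute means an unset cap never reaches the serialized request at all. A self-contained sketch of that serde behavior:

```rust
use serde::Serialize;

// Reduced mirror of `Request`, keeping only the field under discussion.
#[derive(Serialize)]
struct Req {
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

fn main() {
    let unset = serde_json::to_string(&Req { max_tokens: None }).unwrap();
    assert_eq!(unset, "{}"); // field omitted entirely

    let capped = serde_json::to_string(&Req { max_tokens: Some(4096) }).unwrap();
    assert_eq!(capped, r#"{"max_tokens":4096}"#);
}
```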