From 5905fbb9accdc5d34b7fec0fe021022a5b38420e Mon Sep 17 00:00:00 2001
From: Roy Williams
Date: Fri, 20 Sep 2024 16:59:12 -0400
Subject: [PATCH] Allow Anthropic custom models to override temperature
 (#18160)

Release Notes:

- Allow Anthropic custom models to override "temperature"

This also centralized the defaulting of "temperature" to be inside of
each model's `into_x` call instead of being sprinkled around the code.
---
 crates/anthropic/src/anthropic.rs               | 14 ++++++++++++++
 crates/assistant/src/context.rs                 |  2 +-
 crates/assistant/src/inline_assistant.rs        |  2 +-
 crates/assistant/src/prompt_library.rs          |  2 +-
 .../assistant/src/slash_command/auto_command.rs |  2 +-
 .../assistant/src/terminal_inline_assistant.rs  |  2 +-
 crates/language_model/src/provider/anthropic.rs | 10 ++++++++--
 crates/language_model/src/provider/cloud.rs     | 16 +++++++++++++---
 crates/language_model/src/provider/ollama.rs    |  2 +-
 crates/language_model/src/request.rs            | 15 ++++++++++-----
 crates/language_model/src/settings.rs           |  2 ++
 crates/semantic_index/src/summary_index.rs      |  2 +-
 12 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs
index f960dc541a2866f4f37a8ea88c78a8e8b78f8310..91b6723e90be97c858ed3992dd5fcc47a418dcf7 100644
--- a/crates/anthropic/src/anthropic.rs
+++ b/crates/anthropic/src/anthropic.rs
@@ -49,6 +49,7 @@ pub enum Model {
         /// Indicates whether this custom model supports caching.
         cache_configuration: Option<AnthropicModelCacheConfiguration>,
         max_output_tokens: Option<u32>,
+        default_temperature: Option<f32>,
     },
 }
 
@@ -124,6 +125,19 @@ impl Model {
         }
     }
 
+    pub fn default_temperature(&self) -> f32 {
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku => 1.0,
+            Self::Custom {
+                default_temperature,
+                ..
+            } => default_temperature.unwrap_or(1.0),
+        }
+    }
+
     pub fn tool_model_id(&self) -> &str {
         if let Self::Custom {
             tool_override: Some(tool_override),
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index 830c0980491f7c9764a08e578adf8c559e33195b..97a5b3ea988bccbdbe90d150d02c53d7e3a97c59 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs
@@ -2180,7 +2180,7 @@ impl Context {
             messages: Vec::new(),
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         };
         for message in self.messages(cx) {
             if message.status != MessageStatus::Done {
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index d95b54d3c6a6ace9ea3c2d25861c17183e624e66..f2428c3a2e94cf1ddcb0de291f999dd08f8f2291 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -2732,7 +2732,7 @@ impl CodegenAlternative {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.,
+            temperature: None,
         })
     }
 
diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs
index 76ee95d5070b8269af61b80a80ffbc6e7c6a7798..24e20a18a799a690721d3efca15ac6aa684bb950 100644
--- a/crates/assistant/src/prompt_library.rs
+++ b/crates/assistant/src/prompt_library.rs
@@ -796,7 +796,7 @@ impl PromptLibrary {
                 }],
                 tools: Vec::new(),
                 stop: Vec::new(),
-                temperature: 1.,
+                temperature: None,
             },
             cx,
         )
diff --git a/crates/assistant/src/slash_command/auto_command.rs b/crates/assistant/src/slash_command/auto_command.rs
index e1f20c311bd36e48f15d1b3b94ef6ee6c8b62082..14cee296820989d4afacc8eec2ed4231ec591408 100644
--- a/crates/assistant/src/slash_command/auto_command.rs
+++ b/crates/assistant/src/slash_command/auto_command.rs
@@ -216,7 +216,7 @@ async fn commands_for_summaries(
         }],
         tools: Vec::new(),
         stop: Vec::new(),
-        temperature: 1.0,
+        temperature: None,
     };
 
     while let Some(current_summaries) = stack.pop() {
diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs
index caf819bae535ee88bd2aed3eaa004f7b79c930c8..e1a26d851003eb572cbdafd9156afb9978bbd489 100644
--- a/crates/assistant/src/terminal_inline_assistant.rs
+++ b/crates/assistant/src/terminal_inline_assistant.rs
@@ -284,7 +284,7 @@ impl TerminalInlineAssistant {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         })
     }
 
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
index 1e3d2750949f16f97fa04f88b48474b30edbd54a..86538bec49172d88aa87375f6335fddac94ea27c 100644
--- a/crates/language_model/src/provider/anthropic.rs
+++ b/crates/language_model/src/provider/anthropic.rs
@@ -51,6 +51,7 @@ pub struct AvailableModel {
     /// Configuration of Anthropic's caching API.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
     pub max_output_tokens: Option<u32>,
+    pub default_temperature: Option<f32>,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -200,6 +201,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                         }
                     }),
                     max_output_tokens: model.max_output_tokens,
+                    default_temperature: model.default_temperature,
                 },
             );
         }
@@ -375,8 +377,11 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request =
-            request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
+        let request = request.into_anthropic(
+            self.model.id().into(),
+            self.model.default_temperature(),
+            self.model.max_output_tokens(),
+        );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;
@@ -405,6 +410,7 @@ impl LanguageModel for AnthropicModel {
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_anthropic(
             self.model.tool_model_id().into(),
+            self.model.default_temperature(),
             self.model.max_output_tokens(),
         );
         request.tool_choice = Some(anthropic::ToolChoice::Tool {
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index 58efb4cfe1b308decf41cd657cb6dbe5f3a02cec..606a6fbacec7b0e0dd6033608856685a189549f9 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -87,6 +87,8 @@ pub struct AvailableModel {
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
+    /// The default temperature to use for this model.
+    pub default_temperature: Option<f32>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -255,6 +257,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                             min_total_token: config.min_total_token,
                         }
                     }),
+                    default_temperature: model.default_temperature,
                     max_output_tokens: model.max_output_tokens,
                 }),
                 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
@@ -516,7 +519,11 @@ impl LanguageModel for CloudLanguageModel {
 
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
+                let request = request.into_anthropic(
+                    model.id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
@@ -642,8 +649,11 @@ impl LanguageModel for CloudLanguageModel {
 
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let mut request =
-                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
+                let mut request = request.into_anthropic(
+                    model.tool_model_id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });
diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs
index 6a3190dee7c229a26d5c73e83e67677e45da9620..a29ff3cf6a7a1a34cbe10bec99583a8cee5a5b00 100644
--- a/crates/language_model/src/provider/ollama.rs
+++ b/crates/language_model/src/provider/ollama.rs
@@ -235,7 +235,7 @@ impl OllamaLanguageModel {
             options: Some(ChatOptions {
                 num_ctx: Some(self.model.max_tokens),
                 stop: Some(request.stop),
-                temperature: Some(request.temperature),
+                temperature: request.temperature.or(Some(1.0)),
                 ..Default::default()
             }),
             tools: vec![],
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index dd480b8aaf38c22a99e8878b4dd8db73f2a125a9..06dde1862ab37ed2a4fbec4a8e67cb1bd18254cf 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -236,7 +236,7 @@ pub struct LanguageModelRequest {
     pub messages: Vec<LanguageModelRequestMessage>,
     pub tools: Vec<LanguageModelRequestTool>,
     pub stop: Vec<String>,
-    pub temperature: f32,
+    pub temperature: Option<f32>,
 }
 
 impl LanguageModelRequest {
@@ -262,7 +262,7 @@ impl LanguageModelRequest {
                 .collect(),
             stream,
             stop: self.stop,
-            temperature: self.temperature,
+            temperature: self.temperature.unwrap_or(1.0),
             max_tokens: max_output_tokens,
             tools: Vec::new(),
             tool_choice: None,
@@ -290,7 +290,7 @@ impl LanguageModelRequest {
                 candidate_count: Some(1),
                 stop_sequences: Some(self.stop),
                 max_output_tokens: None,
-                temperature: Some(self.temperature as f64),
+                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
                 top_p: None,
                 top_k: None,
             }),
@@ -298,7 +298,12 @@ impl LanguageModelRequest {
         }
     }
 
-    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
+    pub fn into_anthropic(
+        self,
+        model: String,
+        default_temperature: f32,
+        max_output_tokens: u32,
+    ) -> anthropic::Request {
         let mut new_messages: Vec<anthropic::Message> = Vec::new();
         let mut system_message = String::new();
@@ -400,7 +405,7 @@ impl LanguageModelRequest {
             tool_choice: None,
             metadata: None,
             stop_sequences: Vec::new(),
-            temperature: Some(self.temperature),
+            temperature: self.temperature.or(Some(default_temperature)),
             top_k: None,
             top_p: None,
         }
diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs
index 8888d51e11c25524cf15beea129973ca033357a5..2bf8deb04238c290e2ff519eb0c0e94371a66b64 100644
--- a/crates/language_model/src/settings.rs
+++ b/crates/language_model/src/settings.rs
@@ -99,6 +99,7 @@ impl AnthropicSettingsContent {
                             tool_override,
                             cache_configuration,
                             max_output_tokens,
+                            default_temperature,
                         } => Some(provider::anthropic::AvailableModel {
                             name,
                             display_name,
@@ -112,6 +113,7 @@ impl AnthropicSettingsContent {
                                 },
                             ),
                             max_output_tokens,
+                            default_temperature,
                         }),
                         _ => None,
                     })
diff --git a/crates/semantic_index/src/summary_index.rs b/crates/semantic_index/src/summary_index.rs
index 08f25ae0287fa21c0e897acabc7611110a0b782d..f4c6d4726c508b8049f54634633f7b09bb6c30d6 100644
--- a/crates/semantic_index/src/summary_index.rs
+++ b/crates/semantic_index/src/summary_index.rs
@@ -562,7 +562,7 @@ impl SummaryIndex {
             }],
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         };
 
         let code_len = code.len();
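
---

For reviewers: below is a minimal, self-contained Rust sketch of the
temperature resolution order this patch introduces. It is not code from
the tree; `model_default_temperature` and `resolve_temperature` are
hypothetical stand-ins for `anthropic::Model::default_temperature` and
the defaulting now performed inside
`LanguageModelRequest::into_anthropic`.

    // Hypothetical stand-in for `Model::default_temperature`: built-in
    // Claude models use 1.0; a custom model may override it via the new
    // `default_temperature` field.
    fn model_default_temperature(custom_default: Option<f32>) -> f32 {
        custom_default.unwrap_or(1.0)
    }

    // Hypothetical stand-in for the defaulting in `into_anthropic`: an
    // explicit request-level temperature wins, then the model default.
    fn resolve_temperature(
        request_temperature: Option<f32>,
        model_default: f32,
    ) -> Option<f32> {
        request_temperature.or(Some(model_default))
    }

    fn main() {
        // Custom model with a configured default of 0.2 and no
        // request-level override: the custom default is used.
        let t = resolve_temperature(None, model_default_temperature(Some(0.2)));
        assert_eq!(t, Some(0.2));

        // No custom default and no override: falls back to 1.0.
        let t = resolve_temperature(None, model_default_temperature(None));
        assert_eq!(t, Some(1.0));

        // An explicit request temperature beats any model default.
        let t = resolve_temperature(Some(0.7), model_default_temperature(Some(0.2)));
        assert_eq!(t, Some(0.7));
    }

Note that callers now pass `temperature: None` instead of a hard-coded
1.0, so the fallback happens once, inside each provider's `into_x`
conversion, while the request sent over the wire still always carries a
concrete temperature.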