Allow Anthropic custom models to override temperature (#18160)

Created by Roy Williams

Release Notes:

- Allow Anthropic custom models to override `temperature`

This also centralizes the defaulting of `temperature` inside each model's
`into_x` call, instead of sprinkling it around the code.
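
As a sketch of the resulting pattern (simplified; not the literal code), each conversion now resolves the request's optional temperature against the model's default in a single place:

    // Simplified illustration of the centralized defaulting; this helper
    // is hypothetical and only mirrors the `or` chain used in the diffs.
    fn resolve_temperature(requested: Option<f32>, model_default: f32) -> Option<f32> {
        // A temperature set on the request wins; otherwise fall back to the
        // model's default (1.0 for the built-in Anthropic models).
        requested.or(Some(model_default))
    }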

Change summary

crates/anthropic/src/anthropic.rs                  | 14 ++++++++++++++
crates/assistant/src/context.rs                    |  2 +-
crates/assistant/src/inline_assistant.rs           |  2 +-
crates/assistant/src/prompt_library.rs             |  2 +-
crates/assistant/src/slash_command/auto_command.rs |  2 +-
crates/assistant/src/terminal_inline_assistant.rs  |  2 +-
crates/language_model/src/provider/anthropic.rs    | 10 ++++++++--
crates/language_model/src/provider/cloud.rs        | 16 +++++++++++++---
crates/language_model/src/provider/ollama.rs       |  2 +-
crates/language_model/src/request.rs               | 15 ++++++++++-----
crates/language_model/src/settings.rs              |  2 ++
crates/semantic_index/src/summary_index.rs         |  2 +-
12 files changed, 54 insertions(+), 17 deletions(-)

Detailed changes

crates/anthropic/src/anthropic.rs

@@ -49,6 +49,7 @@ pub enum Model {
         /// Indicates whether this custom model supports caching.
         cache_configuration: Option<AnthropicModelCacheConfiguration>,
         max_output_tokens: Option<u32>,
+        default_temperature: Option<f32>,
     },
 }
 
@@ -124,6 +125,19 @@ impl Model {
         }
     }
 
+    pub fn default_temperature(&self) -> f32 {
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku => 1.0,
+            Self::Custom {
+                default_temperature,
+                ..
+            } => default_temperature.unwrap_or(1.0),
+        }
+    }
+
     pub fn tool_model_id(&self) -> &str {
         if let Self::Custom {
             tool_override: Some(tool_override),
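
For illustration (the values here are made up), the `default_temperature` method added above keeps the built-in models at 1.0 and lets a custom model's `Some(...)` value take effect:

    // Illustrative only; not code from this PR.
    // `Model::Custom { default_temperature: Some(0.7), .. }` resolves to 0.7,
    // while an unset `default_temperature` falls back to 1.0:
    assert_eq!(Some(0.7_f32).unwrap_or(1.0), 0.7);
    assert_eq!(None::<f32>.unwrap_or(1.0), 1.0);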

crates/assistant/src/context.rs

@@ -2180,7 +2180,7 @@ impl Context {
             messages: Vec::new(),
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         };
         for message in self.messages(cx) {
             if message.status != MessageStatus::Done {

crates/language_model/src/provider/anthropic.rs

@@ -51,6 +51,7 @@ pub struct AvailableModel {
     /// Configuration of Anthropic's caching API.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
     pub max_output_tokens: Option<u32>,
+    pub default_temperature: Option<f32>,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -200,6 +201,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                         }
                     }),
                     max_output_tokens: model.max_output_tokens,
+                    default_temperature: model.default_temperature,
                 },
             );
         }
@@ -375,8 +377,11 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request =
-            request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
+        let request = request.into_anthropic(
+            self.model.id().into(),
+            self.model.default_temperature(),
+            self.model.max_output_tokens(),
+        );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;
@@ -405,6 +410,7 @@ impl LanguageModel for AnthropicModel {
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_anthropic(
             self.model.tool_model_id().into(),
+            self.model.default_temperature(),
             self.model.max_output_tokens(),
         );
         request.tool_choice = Some(anthropic::ToolChoice::Tool {

crates/language_model/src/provider/cloud.rs

@@ -87,6 +87,8 @@ pub struct AvailableModel {
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
+    /// The default temperature to use for this model.
+    pub default_temperature: Option<f32>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -255,6 +257,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                             min_total_token: config.min_total_token,
                         }
                     }),
+                    default_temperature: model.default_temperature,
                     max_output_tokens: model.max_output_tokens,
                 }),
                 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
@@ -516,7 +519,11 @@ impl LanguageModel for CloudLanguageModel {
 
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
+                let request = request.into_anthropic(
+                    model.id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
@@ -642,8 +649,11 @@ impl LanguageModel for CloudLanguageModel {
 
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let mut request =
-                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
+                let mut request = request.into_anthropic(
+                    model.tool_model_id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });

crates/language_model/src/provider/ollama.rs

@@ -235,7 +235,7 @@ impl OllamaLanguageModel {
             options: Some(ChatOptions {
                 num_ctx: Some(self.model.max_tokens),
                 stop: Some(request.stop),
-                temperature: Some(request.temperature),
+                temperature: request.temperature.or(Some(1.0)),
                 ..Default::default()
             }),
             tools: vec![],

crates/language_model/src/request.rs

@@ -236,7 +236,7 @@ pub struct LanguageModelRequest {
     pub messages: Vec<LanguageModelRequestMessage>,
     pub tools: Vec<LanguageModelRequestTool>,
     pub stop: Vec<String>,
-    pub temperature: f32,
+    pub temperature: Option<f32>,
 }
 
 impl LanguageModelRequest {
@@ -262,7 +262,7 @@ impl LanguageModelRequest {
                 .collect(),
             stream,
             stop: self.stop,
-            temperature: self.temperature,
+            temperature: self.temperature.unwrap_or(1.0),
             max_tokens: max_output_tokens,
             tools: Vec::new(),
             tool_choice: None,
@@ -290,7 +290,7 @@ impl LanguageModelRequest {
                 candidate_count: Some(1),
                 stop_sequences: Some(self.stop),
                 max_output_tokens: None,
-                temperature: Some(self.temperature as f64),
+                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
                 top_p: None,
                 top_k: None,
             }),
@@ -298,7 +298,12 @@ impl LanguageModelRequest {
         }
     }
 
-    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
+    pub fn into_anthropic(
+        self,
+        model: String,
+        default_temperature: f32,
+        max_output_tokens: u32,
+    ) -> anthropic::Request {
         let mut new_messages: Vec<anthropic::Message> = Vec::new();
         let mut system_message = String::new();
 
@@ -400,7 +405,7 @@ impl LanguageModelRequest {
             tool_choice: None,
             metadata: None,
             stop_sequences: Vec::new(),
-            temperature: Some(self.temperature),
+            temperature: self.temperature.or(Some(default_temperature)),
             top_k: None,
             top_p: None,
         }
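
Putting it together, the resolution order in `into_anthropic` is: a temperature set on the request first, then the model's `default_temperature`. A minimal sketch of that precedence (not a test from this PR):

    // Minimal sketch; assumes the `or` chain shown in the hunk above.
    let model_default = 0.7_f32; // e.g. a custom model's default_temperature
    let unset: Option<f32> = None;
    assert_eq!(unset.or(Some(model_default)), Some(0.7));         // falls back to the model default
    assert_eq!(Some(0.2_f32).or(Some(model_default)), Some(0.2)); // request value wins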

crates/language_model/src/settings.rs

@@ -99,6 +99,7 @@ impl AnthropicSettingsContent {
                                     tool_override,
                                     cache_configuration,
                                     max_output_tokens,
+                                    default_temperature,
                                 } => Some(provider::anthropic::AvailableModel {
                                     name,
                                     display_name,
@@ -112,6 +113,7 @@ impl AnthropicSettingsContent {
                                         },
                                     ),
                                     max_output_tokens,
+                                    default_temperature,
                                 }),
                                 _ => None,
                             })
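
With the settings plumbing in place, a `default_temperature` declared on a custom model flows through unchanged. A hedged trace of the path (field names are from the diffs above; everything else is elided):

    // Illustration of the forwarding, not literal code:
    // AnthropicSettingsContent model { default_temperature: Some(0.7), .. }
    //   -> provider::anthropic::AvailableModel { default_temperature: Some(0.7), .. }
    //   -> anthropic::Model::Custom { default_temperature: Some(0.7), .. }
    //   -> request.into_anthropic(id, model.default_temperature(), max_output_tokens)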