Add ability to use o1-preview and o1-mini as custom models (#17804)

jvmncs , Peter , Bennet , and Marshall Bowers created

This is a barebones modification of the OpenAI provider code to
accommodate non-streaming completions. This is specifically for the o1
models, which do not support streaming. Tested that this is working by
running a `/workflow` with the following (arbitrarily chosen) settings:

```json
{
  "language_models": {
    "openai": {
      "version": "1",
      "available_models": [
        {
          "name": "o1-preview",
          "display_name": "o1-preview",
          "max_tokens": 128000,
          "max_completion_tokens": 30000
        },
        {
          "name": "o1-mini",
          "display_name": "o1-mini",
          "max_tokens": 128000,
          "max_completion_tokens": 20000
        }
      ]
    }
  },
}
```

Release Notes:

- Changed  `low_speed_timeout_in_seconds` option to `600` for OpenAI
provider to accommodate recent o1 model release.

---------

Co-authored-by: Peter <peter@zed.dev>
Co-authored-by: Bennet <bennet@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

assets/settings/default.json                  |   3 
crates/assistant/src/assistant_settings.rs    |   2 
crates/assistant/src/inline_assistant.rs      |   2 
crates/language_model/src/provider/cloud.rs   |   3 
crates/language_model/src/provider/open_ai.rs |   2 
crates/language_model/src/settings.rs         |   2 
crates/open_ai/src/open_ai.rs                 | 126 ++++++++++++++++++++
7 files changed, 136 insertions(+), 4 deletions(-)

Detailed changes

assets/settings/default.json 🔗

@@ -916,7 +916,8 @@
     },
     "openai": {
       "version": "1",
-      "api_url": "https://api.openai.com/v1"
+      "api_url": "https://api.openai.com/v1",
+      "low_speed_timeout_in_seconds": 600
     }
   },
   // Zed's Prettier integration settings.

crates/assistant/src/assistant_settings.rs 🔗

@@ -163,11 +163,13 @@ impl AssistantSettingsContent {
                                                     display_name,
                                                     max_tokens,
                                                     max_output_tokens,
+                                                    max_completion_tokens: None,
                                                 } => Some(open_ai::AvailableModel {
                                                     name,
                                                     display_name,
                                                     max_tokens,
                                                     max_output_tokens,
+                                                    max_completion_tokens: None,
                                                 }),
                                                 _ => None,
                                             })

crates/language_model/src/provider/cloud.rs 🔗

@@ -78,6 +78,8 @@ pub struct AvailableModel {
     pub max_tokens: usize,
     /// The maximum number of output tokens allowed by the model.
     pub max_output_tokens: Option<u32>,
+    /// The maximum number of completion tokens allowed by the model (o1-* only)
+    pub max_completion_tokens: Option<u32>,
     /// Override this model with a different Anthropic model for tool calls.
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
@@ -257,6 +259,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
+                    max_completion_tokens: model.max_completion_tokens,
                 }),
                 AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                     name: model.name.clone(),

crates/language_model/src/provider/open_ai.rs 🔗

@@ -43,6 +43,7 @@ pub struct AvailableModel {
     pub display_name: Option<String>,
     pub max_tokens: usize,
     pub max_output_tokens: Option<u32>,
+    pub max_completion_tokens: Option<u32>,
 }
 
 pub struct OpenAiLanguageModelProvider {
@@ -175,6 +176,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
+                    max_completion_tokens: model.max_completion_tokens,
                 },
             );
         }

crates/language_model/src/settings.rs 🔗

@@ -178,11 +178,13 @@ impl OpenAiSettingsContent {
                                     display_name,
                                     max_tokens,
                                     max_output_tokens,
+                                    max_completion_tokens,
                                 } => Some(provider::open_ai::AvailableModel {
                                     name,
                                     max_tokens,
                                     max_output_tokens,
                                     display_name,
+                                    max_completion_tokens,
                                 }),
                                 _ => None,
                             })

crates/open_ai/src/open_ai.rs 🔗

@@ -1,12 +1,21 @@
 mod supported_countries;
 
 use anyhow::{anyhow, Context, Result};
-use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use futures::{
+    io::BufReader,
+    stream::{self, BoxStream},
+    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
+use std::{
+    convert::TryFrom,
+    future::{self, Future},
+    pin::Pin,
+    time::Duration,
+};
 use strum::EnumIter;
 
 pub use supported_countries::*;
@@ -72,6 +81,7 @@ pub enum Model {
         display_name: Option<String>,
         max_tokens: usize,
         max_output_tokens: Option<u32>,
+        max_completion_tokens: Option<u32>,
     },
 }
 
@@ -139,6 +149,7 @@ pub struct Request {
     pub stream: bool,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub max_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub stop: Vec<String>,
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
@@ -263,6 +274,111 @@ pub struct ResponseStreamEvent {
     pub usage: Option<Usage>,
 }
 
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Response {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<Choice>,
+    pub usage: Usage,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Choice {
+    pub index: u32,
+    pub message: RequestMessage,
+    pub finish_reason: Option<String>,
+}
+
+pub async fn complete(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+    low_speed_timeout: Option<Duration>,
+) -> Result<Response> {
+    let uri = format!("{api_url}/chat/completions");
+    let mut request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key));
+    if let Some(low_speed_timeout) = low_speed_timeout {
+        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+    };
+
+    let mut request_body = request;
+    request_body.stream = false;
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
+    let mut response = client.send(request).await?;
+
+    if response.status().is_success() {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+        let response: Response = serde_json::from_str(&body)?;
+        Ok(response)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAiResponse {
+            error: OpenAiError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAiError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAiResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
+
+fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
+    ResponseStreamEvent {
+        created: response.created as u32,
+        model: response.model,
+        choices: response
+            .choices
+            .into_iter()
+            .map(|choice| ChoiceDelta {
+                index: choice.index,
+                delta: ResponseMessageDelta {
+                    role: Some(match choice.message {
+                        RequestMessage::Assistant { .. } => Role::Assistant,
+                        RequestMessage::User { .. } => Role::User,
+                        RequestMessage::System { .. } => Role::System,
+                        RequestMessage::Tool { .. } => Role::Tool,
+                    }),
+                    content: match choice.message {
+                        RequestMessage::Assistant { content, .. } => content,
+                        RequestMessage::User { content } => Some(content),
+                        RequestMessage::System { content } => Some(content),
+                        RequestMessage::Tool { content, .. } => Some(content),
+                    },
+                    tool_calls: None,
+                },
+                finish_reason: choice.finish_reason,
+            })
+            .collect(),
+        usage: Some(response.usage),
+    }
+}
+
 pub async fn stream_completion(
     client: &dyn HttpClient,
     api_url: &str,
@@ -270,6 +386,12 @@ pub async fn stream_completion(
     request: Request,
     low_speed_timeout: Option<Duration>,
 ) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+    if request.model == "o1-preview" || request.model == "o1-mini" {
+        let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
+        let response_stream_event = response.map(adapt_response_to_stream);
+        return Ok(stream::once(future::ready(response_stream_event)).boxed());
+    }
+
     let uri = format!("{api_url}/chat/completions");
     let mut request_builder = HttpRequest::builder()
         .method(Method::POST)