Wire up Azure OpenAI completion provider (#8646)

Marshall Bowers created

This PR wires up support for [Azure
OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
as an alternative AI provider in the assistant panel.

This can be configured using the following in the settings file:

```json
{
  "assistant": {
    "provider": {
      "type": "azure_openai",
      "api_url": "https://{your-resource-name}.openai.azure.com",
      "deployment_id": "gpt-4",
      "api_version": "2023-05-15"
    }
  }
}
```

You will need to deploy a model within Azure and update the settings
accordingly.

Release Notes:

- N/A

Change summary

Cargo.lock                                    |   1 
assets/settings/default.json                  |  16 +
crates/ai/Cargo.toml                          |   1 
crates/ai/src/providers/open_ai/completion.rs |  67 ++++++-
crates/assistant/src/assistant_panel.rs       |  77 ++++---
crates/assistant/src/assistant_settings.rs    | 193 ++++++++++++++++++--
crates/client/src/telemetry.rs                |   2 
7 files changed, 291 insertions(+), 66 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -93,6 +93,7 @@ dependencies = [
  "postage",
  "rand 0.8.5",
  "rusqlite",
+ "schemars",
  "serde",
  "serde_json",
  "tiktoken-rs",

assets/settings/default.json 🔗

@@ -228,15 +228,29 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
+    // Deprecated: Please use `provider.api_url` instead.
     // The default OpenAI API endpoint to use when starting new conversations.
     "openai_api_url": "https://api.openai.com/v1",
+    // Deprecated: Please use `provider.default_model` instead.
     // The default OpenAI model to use when starting new conversations. This
     // setting can take three values:
     //
     // 1. "gpt-3.5-turbo-0613""
     // 2. "gpt-4-0613""
     // 3. "gpt-4-1106-preview"
-    "default_open_ai_model": "gpt-4-1106-preview"
+    "default_open_ai_model": "gpt-4-1106-preview",
+    "provider": {
+      "type": "openai",
+      // The default OpenAI API endpoint to use when starting new conversations.
+      "api_url": "https://api.openai.com/v1",
+      // The default OpenAI model to use when starting new conversations. This
+      // setting can take three values:
+      //
+    // 1. "gpt-3.5-turbo-0613"
+    // 2. "gpt-4-0613"
+      // 3. "gpt-4-1106-preview"
+      "default_model": "gpt-4-1106-preview"
+    }
   },
   // Whether the screen sharing icon is shown in the os status bar.
   "show_call_status_icon": true,

crates/ai/Cargo.toml 🔗

@@ -29,6 +29,7 @@ parse_duration = "2.1.1"
 postage.workspace = true
 rand.workspace = true
 rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 tiktoken-rs.workspace = true

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -1,3 +1,10 @@
+use std::{
+    env,
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+
 use anyhow::{anyhow, Result};
 use futures::{
     future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
@@ -6,23 +13,17 @@ use futures::{
 use gpui::{AppContext, BackgroundExecutor};
 use isahc::{http::StatusCode, Request, RequestExt};
 use parking_lot::RwLock;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::{
-    env,
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
 use util::ResultExt;
 
+use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
 use crate::{
     auth::{CredentialProvider, ProviderCredential},
     completion::{CompletionProvider, CompletionRequest},
     models::LanguageModel,
 };
 
-use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
-
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
 pub enum Role {
@@ -196,12 +197,56 @@ async fn stream_completion(
     }
 }
 
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
+pub enum AzureOpenAiApiVersion {
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-03-15-preview")]
+    V2023_03_15Preview,
+    #[serde(rename = "2023-05-15")]
+    V2023_05_15,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-06-01-preview")]
+    V2023_06_01Preview,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-07-01-preview")]
+    V2023_07_01Preview,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-08-01-preview")]
+    V2023_08_01Preview,
+    /// Retiring April 2, 2024.
+    #[serde(rename = "2023-09-01-preview")]
+    V2023_09_01Preview,
+    #[serde(rename = "2023-12-01-preview")]
+    V2023_12_01Preview,
+    #[serde(rename = "2024-02-15-preview")]
+    V2024_02_15Preview,
+}
+
+impl fmt::Display for AzureOpenAiApiVersion {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                Self::V2023_03_15Preview => "2023-03-15-preview",
+                Self::V2023_05_15 => "2023-05-15",
+                Self::V2023_06_01Preview => "2023-06-01-preview",
+                Self::V2023_07_01Preview => "2023-07-01-preview",
+                Self::V2023_08_01Preview => "2023-08-01-preview",
+                Self::V2023_09_01Preview => "2023-09-01-preview",
+                Self::V2023_12_01Preview => "2023-12-01-preview",
+                Self::V2024_02_15Preview => "2024-02-15-preview",
+            }
+        )
+    }
+}
+
 #[derive(Clone)]
 pub enum OpenAiCompletionProviderKind {
     OpenAi,
     AzureOpenAi {
         deployment_id: String,
-        api_version: String,
+        api_version: AzureOpenAiApiVersion,
     },
 }
 
@@ -217,8 +262,8 @@ impl OpenAiCompletionProviderKind {
                 deployment_id,
                 api_version,
             } => {
-                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
-                format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
+                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
+                format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
             }
         }
     }

crates/assistant/src/assistant_panel.rs 🔗

@@ -124,16 +124,18 @@ impl AssistantPanel {
                 .await
                 .log_err()
                 .unwrap_or_default();
-            let (api_url, model_name) = cx.update(|cx| {
+            let (provider_kind, api_url, model_name) = cx.update(|cx| {
                 let settings = AssistantSettings::get_global(cx);
-                (
-                    settings.openai_api_url.clone(),
-                    settings.default_open_ai_model.full_name().to_string(),
-                )
-            })?;
+                anyhow::Ok((
+                    settings.provider_kind()?,
+                    settings.provider_api_url()?,
+                    settings.provider_model_name()?,
+                ))
+            })??;
+
             let completion_provider = OpenAiCompletionProvider::new(
                 api_url,
-                OpenAiCompletionProviderKind::OpenAi,
+                provider_kind,
                 model_name,
                 cx.background_executor().clone(),
             )
@@ -693,24 +695,29 @@ impl AssistantPanel {
             Task::ready(Ok(Vec::new()))
         };
 
-        let mut model = AssistantSettings::get_global(cx)
-            .default_open_ai_model
-            .clone();
-        let model_name = model.full_name();
-
-        let prompt = cx.background_executor().spawn(async move {
-            let snippets = snippets.await?;
+        let Some(mut model_name) = AssistantSettings::get_global(cx)
+            .provider_model_name()
+            .log_err()
+        else {
+            return;
+        };
 
-            let language_name = language_name.as_deref();
-            generate_content_prompt(
-                user_prompt,
-                language_name,
-                buffer,
-                range,
-                snippets,
-                model_name,
-                project_name,
-            )
+        let prompt = cx.background_executor().spawn({
+            let model_name = model_name.clone();
+            async move {
+                let snippets = snippets.await?;
+
+                let language_name = language_name.as_deref();
+                generate_content_prompt(
+                    user_prompt,
+                    language_name,
+                    buffer,
+                    range,
+                    snippets,
+                    &model_name,
+                    project_name,
+                )
+            }
         });
 
         let mut messages = Vec::new();
@@ -722,7 +729,7 @@ impl AssistantPanel {
                     .messages(cx)
                     .map(|message| message.to_open_ai_message(buffer)),
             );
-            model = conversation.model.clone();
+            model_name = conversation.model.full_name().to_string();
         }
 
         cx.spawn(|_, mut cx| async move {
@@ -735,7 +742,7 @@ impl AssistantPanel {
             });
 
             let request = Box::new(OpenAiRequest {
-                model: model.full_name().into(),
+                model: model_name,
                 messages,
                 stream: true,
                 stop: vec!["|END|>".to_string()],
@@ -1454,8 +1461,14 @@ impl Conversation {
         });
 
         let settings = AssistantSettings::get_global(cx);
-        let model = settings.default_open_ai_model.clone();
-        let api_url = settings.openai_api_url.clone();
+        let model = settings
+            .provider_model()
+            .log_err()
+            .unwrap_or(OpenAiModel::FourTurbo);
+        let api_url = settings
+            .provider_api_url()
+            .log_err()
+            .unwrap_or_else(|| OPEN_AI_API_URL.to_string());
 
         let mut this = Self {
             id: Some(Uuid::new_v4().to_string()),
@@ -3655,9 +3668,9 @@ fn report_assistant_event(
     let client = workspace.read(cx).project().read(cx).client();
     let telemetry = client.telemetry();
 
-    let model = AssistantSettings::get_global(cx)
-        .default_open_ai_model
-        .clone();
+    let Ok(model_name) = AssistantSettings::get_global(cx).provider_model_name() else {
+        return;
+    };
 
-    telemetry.report_assistant_event(conversation_id, assistant_kind, model.full_name())
+    telemetry.report_assistant_event(conversation_id, assistant_kind, &model_name)
 }

crates/assistant/src/assistant_settings.rs 🔗

@@ -1,10 +1,14 @@
-use anyhow;
+use ai::providers::open_ai::{
+    AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
+};
+use anyhow::anyhow;
 use gpui::Pixels;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Settings;
 
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[serde(rename_all = "snake_case")]
 pub enum OpenAiModel {
     #[serde(rename = "gpt-3.5-turbo-0613")]
     ThreePointFiveTurbo,
@@ -17,25 +21,25 @@ pub enum OpenAiModel {
 impl OpenAiModel {
     pub fn full_name(&self) -> &'static str {
         match self {
-            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
-            OpenAiModel::Four => "gpt-4-0613",
-            OpenAiModel::FourTurbo => "gpt-4-1106-preview",
+            Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+            Self::Four => "gpt-4-0613",
+            Self::FourTurbo => "gpt-4-1106-preview",
         }
     }
 
     pub fn short_name(&self) -> &'static str {
         match self {
-            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
-            OpenAiModel::Four => "gpt-4",
-            OpenAiModel::FourTurbo => "gpt-4-turbo",
+            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            Self::Four => "gpt-4",
+            Self::FourTurbo => "gpt-4-turbo",
         }
     }
 
     pub fn cycle(&self) -> Self {
         match self {
-            OpenAiModel::ThreePointFiveTurbo => OpenAiModel::Four,
-            OpenAiModel::Four => OpenAiModel::FourTurbo,
-            OpenAiModel::FourTurbo => OpenAiModel::ThreePointFiveTurbo,
+            Self::ThreePointFiveTurbo => Self::Four,
+            Self::Four => Self::FourTurbo,
+            Self::FourTurbo => Self::ThreePointFiveTurbo,
         }
     }
 }
@@ -48,14 +52,99 @@ pub enum AssistantDockPosition {
     Bottom,
 }
 
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Deserialize)]
 pub struct AssistantSettings {
+    /// Whether to show the assistant panel button in the status bar.
     pub button: bool,
+    /// Where to dock the assistant.
     pub dock: AssistantDockPosition,
+    /// Default width in pixels when the assistant is docked to the left or right.
     pub default_width: Pixels,
+    /// Default height in pixels when the assistant is docked to the bottom.
     pub default_height: Pixels,
+    /// The default OpenAI model to use when starting new conversations.
+    #[deprecated = "Please use `provider.default_model` instead."]
     pub default_open_ai_model: OpenAiModel,
+    /// OpenAI API base URL to use when starting new conversations.
+    #[deprecated = "Please use `provider.api_url` instead."]
     pub openai_api_url: String,
+    /// The settings for the AI provider.
+    pub provider: AiProviderSettings,
+}
+
+impl AssistantSettings {
+    pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
+            AiProviderSettings::AzureOpenAi(settings) => {
+                let deployment_id = settings
+                    .deployment_id
+                    .clone()
+                    .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
+                let api_version = settings
+                    .api_version
+                    .ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
+
+                Ok(OpenAiCompletionProviderKind::AzureOpenAi {
+                    deployment_id,
+                    api_version,
+                })
+            }
+        }
+    }
+
+    pub fn provider_api_url(&self) -> anyhow::Result<String> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(settings) => Ok(settings
+                .api_url
+                .clone()
+                .unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
+            AiProviderSettings::AzureOpenAi(settings) => settings
+                .api_url
+                .clone()
+                .ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
+        }
+    }
+
+    pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(settings) => {
+                Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
+            }
+            AiProviderSettings::AzureOpenAi(_settings) => {
+                // TODO: We need to use an Azure OpenAI model here.
+                Ok(OpenAiModel::FourTurbo)
+            }
+        }
+    }
+
+    pub fn provider_model_name(&self) -> anyhow::Result<String> {
+        match &self.provider {
+            AiProviderSettings::OpenAi(settings) => Ok(settings
+                .default_model
+                .unwrap_or(OpenAiModel::FourTurbo)
+                .full_name()
+                .to_string()),
+            AiProviderSettings::AzureOpenAi(settings) => settings
+                .deployment_id
+                .clone()
+                .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
+        }
+    }
+}
+
+impl Settings for AssistantSettings {
+    const KEY: Option<&'static str> = Some("assistant");
+
+    type FileContent = AssistantSettingsContent;
+
+    fn load(
+        default_value: &Self::FileContent,
+        user_values: &[&Self::FileContent],
+        _: &mut gpui::AppContext,
+    ) -> anyhow::Result<Self> {
+        Self::load_via_json_merge(default_value, user_values)
+    }
 }
 
 /// Assistant panel settings
@@ -77,26 +166,88 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: 320
     pub default_height: Option<f32>,
+    /// Deprecated: Please use `provider.default_model` instead.
     /// The default OpenAI model to use when starting new conversations.
     ///
     /// Default: gpt-4-1106-preview
+    #[deprecated = "Please use `provider.default_model` instead."]
     pub default_open_ai_model: Option<OpenAiModel>,
+    /// Deprecated: Please use `provider.api_url` instead.
     /// OpenAI API base URL to use when starting new conversations.
     ///
     /// Default: https://api.openai.com/v1
+    #[deprecated = "Please use `provider.api_url` instead."]
     pub openai_api_url: Option<String>,
+    /// The settings for the AI provider.
+    #[serde(default)]
+    pub provider: AiProviderSettingsContent,
 }
 
-impl Settings for AssistantSettings {
-    const KEY: Option<&'static str> = Some("assistant");
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum AiProviderSettings {
+    /// The settings for the OpenAI provider.
+    #[serde(rename = "openai")]
+    OpenAi(OpenAiProviderSettings),
+    /// The settings for the Azure OpenAI provider.
+    #[serde(rename = "azure_openai")]
+    AzureOpenAi(AzureOpenAiProviderSettings),
+}
 
-    type FileContent = AssistantSettingsContent;
+/// The settings for the AI provider used by the Zed Assistant.
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum AiProviderSettingsContent {
+    /// The settings for the OpenAI provider.
+    #[serde(rename = "openai")]
+    OpenAi(OpenAiProviderSettingsContent),
+    /// The settings for the Azure OpenAI provider.
+    #[serde(rename = "azure_openai")]
+    AzureOpenAi(AzureOpenAiProviderSettingsContent),
+}
 
-    fn load(
-        default_value: &Self::FileContent,
-        user_values: &[&Self::FileContent],
-        _: &mut gpui::AppContext,
-    ) -> anyhow::Result<Self> {
-        Self::load_via_json_merge(default_value, user_values)
+impl Default for AiProviderSettingsContent {
+    fn default() -> Self {
+        Self::OpenAi(OpenAiProviderSettingsContent::default())
     }
 }
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct OpenAiProviderSettings {
+    /// The OpenAI API base URL to use when starting new conversations.
+    pub api_url: Option<String>,
+    /// The default OpenAI model to use when starting new conversations.
+    pub default_model: Option<OpenAiModel>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct OpenAiProviderSettingsContent {
+    /// The OpenAI API base URL to use when starting new conversations.
+    ///
+    /// Default: https://api.openai.com/v1
+    pub api_url: Option<String>,
+    /// The default OpenAI model to use when starting new conversations.
+    ///
+    /// Default: gpt-4-1106-preview
+    pub default_model: Option<OpenAiModel>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct AzureOpenAiProviderSettings {
+    /// The Azure OpenAI API base URL to use when starting new conversations.
+    pub api_url: Option<String>,
+    /// The Azure OpenAI API version.
+    pub api_version: Option<AzureOpenAiApiVersion>,
+    /// The Azure OpenAI API deployment ID.
+    pub deployment_id: Option<String>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct AzureOpenAiProviderSettingsContent {
+    /// The Azure OpenAI API base URL to use when starting new conversations.
+    pub api_url: Option<String>,
+    /// The Azure OpenAI API version.
+    pub api_version: Option<AzureOpenAiApiVersion>,
+    /// The Azure OpenAI deployment ID.
+    pub deployment_id: Option<String>,
+}

crates/client/src/telemetry.rs 🔗

@@ -263,7 +263,7 @@ impl Telemetry {
         self: &Arc<Self>,
         conversation_id: Option<String>,
         kind: AssistantKind,
-        model: &'static str,
+        model: &str,
     ) {
         let event = Event::Assistant(AssistantEvent {
             conversation_id,