Stub out support for Azure OpenAI (#8624)

Created by Marshall Bowers

This PR stubs out support for [Azure
OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
within the `OpenAiCompletionProvider`.

Some additional wiring is still required before it is accessible, but the
necessary hooks are now in place.
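
As a rough sketch of that remaining wiring, a follow-up could construct the provider with the Azure variant along these lines. This is not part of the PR; the function name, resource URL, deployment ID, API version, and model name are hypothetical placeholders.

```rust
use ai::providers::open_ai::{OpenAiCompletionProvider, OpenAiCompletionProviderKind};
use gpui::BackgroundExecutor;

// Hypothetical follow-up wiring (not included in this PR): build a provider
// that talks to an Azure OpenAI deployment instead of OpenAI proper.
async fn azure_completion_provider(executor: BackgroundExecutor) -> OpenAiCompletionProvider {
    OpenAiCompletionProvider::new(
        // Placeholder Azure resource URL.
        "https://my-resource.openai.azure.com".to_string(),
        OpenAiCompletionProviderKind::AzureOpenAi {
            // Placeholder deployment ID and API version.
            deployment_id: "my-deployment".to_string(),
            api_version: "2023-05-15".to_string(),
        },
        // Placeholder model name.
        "gpt-4".to_string(),
        executor,
    )
    .await
}
```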

Release Notes:

- N/A

Change summary

crates/ai/src/providers.rs                    |  0 
crates/ai/src/providers/open_ai/completion.rs | 58 +++++++++++++++++++-
crates/assistant/src/assistant_panel.rs       |  8 ++
3 files changed, 59 insertions(+), 7 deletions(-)

Detailed changes

crates/ai/src/providers/open_ai/completion.rs

@@ -102,8 +102,9 @@ pub struct OpenAiResponseStreamEvent {
     pub usage: Option<OpenAiUsage>,
 }
 
-pub async fn stream_completion(
+async fn stream_completion(
     api_url: String,
+    kind: OpenAiCompletionProviderKind,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,10 +118,11 @@ pub async fn stream_completion(
 
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
+    let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{api_url}/chat/completions"))
+    let mut response = Request::post(kind.completions_endpoint_url(&api_url))
         .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
+        .header(auth_header_name, auth_header_value)
         .body(json_data)?
         .send_async()
         .await?;
@@ -194,22 +196,65 @@ pub async fn stream_completion(
     }
 }
 
+#[derive(Clone)]
+pub enum OpenAiCompletionProviderKind {
+    OpenAi,
+    AzureOpenAi {
+        deployment_id: String,
+        api_version: String,
+    },
+}
+
+impl OpenAiCompletionProviderKind {
+    /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
+    fn completions_endpoint_url(&self, api_url: &str) -> String {
+        match self {
+            Self::OpenAi => {
+                // https://platform.openai.com/docs/api-reference/chat/create
+                format!("{api_url}/chat/completions")
+            }
+            Self::AzureOpenAi {
+                deployment_id,
+                api_version,
+            } => {
+                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
+                format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
+            }
+        }
+    }
+
+    /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
+    fn auth_header(&self, api_key: String) -> (&'static str, String) {
+        match self {
+            Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
+            Self::AzureOpenAi { .. } => ("Api-Key", api_key),
+        }
+    }
+}
+
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
     api_url: String,
+    kind: OpenAiCompletionProviderKind,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
 impl OpenAiCompletionProvider {
-    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        kind: OpenAiCompletionProviderKind,
+        model_name: String,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             api_url,
+            kind,
             model,
             credential,
             executor,
@@ -297,6 +342,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
     }
+
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
@@ -307,7 +353,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
         let api_url = self.api_url.clone();
-        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
+        let kind = self.kind.clone();
+        let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response
@@ -322,6 +369,7 @@ impl CompletionProvider for OpenAiCompletionProvider {
         }
         .boxed()
     }
+
     fn box_clone(&self) -> Box<dyn CompletionProvider> {
         Box::new((*self).clone())
     }
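
For illustration only, a small unit test inside the same module could exercise the two variants of `OpenAiCompletionProviderKind`. The expected strings below follow directly from the code above; the URLs, deployment ID, API version, and keys are made-up placeholders.

```rust
#[cfg(test)]
mod kind_tests {
    use super::OpenAiCompletionProviderKind;

    #[test]
    fn endpoint_url_and_auth_header_differ_per_kind() {
        // OpenAI proper: chat completions path, bearer-token auth header.
        let open_ai = OpenAiCompletionProviderKind::OpenAi;
        assert_eq!(
            open_ai.completions_endpoint_url("https://api.openai.com/v1"),
            "https://api.openai.com/v1/chat/completions"
        );
        assert_eq!(
            open_ai.auth_header("sk-example".to_string()),
            ("Authorization", "Bearer sk-example".to_string())
        );

        // Azure OpenAI: deployment-scoped endpoint with an api-version query
        // parameter, and an Api-Key auth header instead of a bearer token.
        let azure = OpenAiCompletionProviderKind::AzureOpenAi {
            deployment_id: "my-deployment".to_string(),
            api_version: "2023-05-15".to_string(),
        };
        assert_eq!(
            azure.completions_endpoint_url("https://my-resource.openai.azure.com"),
            "https://my-resource.openai.azure.com/openai/deployments/my-deployment/completions?api-version=2023-05-15"
        );
        assert_eq!(
            azure.auth_header("azure-key".to_string()),
            ("Api-Key", "azure-key".to_string())
        );
    }
}
```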

crates/assistant/src/assistant_panel.rs

@@ -7,11 +7,13 @@ use crate::{
     SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
 };
 use ai::prompts::repository_context::PromptCodeSnippet;
-use ai::providers::open_ai::OPEN_AI_API_URL;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
-    providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
+    providers::open_ai::{
+        OpenAiCompletionProvider, OpenAiCompletionProviderKind, OpenAiRequest, RequestMessage,
+        OPEN_AI_API_URL,
+    },
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -131,6 +133,7 @@ impl AssistantPanel {
             })?;
             let completion_provider = OpenAiCompletionProvider::new(
                 api_url,
+                OpenAiCompletionProviderKind::OpenAi,
                 model_name,
                 cx.background_executor().clone(),
             )
@@ -1533,6 +1536,7 @@ impl Conversation {
                 api_url
                     .clone()
                     .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
+                OpenAiCompletionProviderKind::OpenAi,
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )