Allow OpenAI API URL to be configured via `assistant.openai_api_url` (#7552)

Created by Yesterday17 and Marshall Bowers

Partially fixes #4321, since the Azure OpenAI API can be adapted to the
OpenAI API by pointing the client at a different endpoint.

Release Notes:

- Added the `assistant.openai_api_url` setting to allow the OpenAI API URL
  to be configured (see the example below).

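For example, a user can point the assistant at any OpenAI-compatible
endpoint by overriding the new setting in their `settings.json` (the
proxy URL below is a placeholder, not a real service):

```json
{
  "assistant": {
    "openai_api_url": "https://my-openai-proxy.example.com/v1"
  }
}
```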
---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

assets/settings/default.json                  |  2 +
crates/ai/src/providers/open_ai/completion.rs | 10 ++++--
crates/ai/src/providers/open_ai/embedding.rs  | 13 +++++++-
crates/assistant/src/assistant.rs             |  1 +
crates/assistant/src/assistant_panel.rs       | 30 ++++++++++++++++++--
crates/assistant/src/assistant_settings.rs    |  5 +++
crates/semantic_index/src/semantic_index.rs   | 11 +++++--
7 files changed, 60 insertions(+), 12 deletions(-)

Detailed changes

assets/settings/default.json

@@ -212,6 +212,8 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
+    // The default OpenAI API endpoint to use when starting new conversations.
+    "openai_api_url": "https://api.openai.com/v1",
     // The default OpenAI model to use when starting new conversations. This
     // setting can take three values:
     //

crates/ai/src/providers/open_ai/completion.rs

@@ -103,6 +103,7 @@ pub struct OpenAiResponseStreamEvent {
 }
 
 pub async fn stream_completion(
+    api_url: String,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,7 +118,7 @@ pub async fn stream_completion(
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
+    let mut response = Request::post(format!("{api_url}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
         .body(json_data)?
@@ -195,18 +196,20 @@ pub async fn stream_completion(
 
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
 impl OpenAiCompletionProvider {
-    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
+            api_url,
             model,
             credential,
             executor,
@@ -303,7 +306,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // which is currently model based, due to the language model.
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
-        let request = stream_completion(credential, self.executor.clone(), prompt);
+        let api_url = self.api_url.clone();
+        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response

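The request URL is built with a plain `format!`, so the configured value
is expected to be a base URL without a trailing slash. A minimal sketch
of the join (the normalizing helper is illustrative, not part of this
PR):

```rust
/// Illustrative helper: joins a configured base URL with the endpoint
/// path the same way the provider's `format!` does, trimming a trailing
/// slash first so "https://host/v1/" does not yield ".../v1//chat/...".
fn completions_url(api_url: &str) -> String {
    format!("{}/chat/completions", api_url.trim_end_matches('/'))
}

fn main() {
    assert_eq!(
        completions_url("https://api.openai.com/v1/"),
        "https://api.openai.com/v1/chat/completions"
    );
}
```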
crates/ai/src/providers/open_ai/embedding.rs

@@ -35,6 +35,7 @@ lazy_static! {
 
 #[derive(Clone)]
 pub struct OpenAiEmbeddingProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
@@ -69,7 +70,11 @@ struct OpenAiEmbeddingUsage {
 }
 
 impl OpenAiEmbeddingProvider {
-    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        client: Arc<dyn HttpClient>,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
@@ -80,6 +85,7 @@ impl OpenAiEmbeddingProvider {
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAiEmbeddingProvider {
+            api_url,
             model,
             credential,
             client,
@@ -130,11 +136,12 @@ impl OpenAiEmbeddingProvider {
     }
     async fn send_request(
         &self,
+        api_url: &str,
         api_key: &str,
         spans: Vec<&str>,
         request_timeout: u64,
     ) -> Result<Response<AsyncBody>> {
-        let request = Request::post(format!("{OPEN_AI_API_URL}/embeddings"))
+        let request = Request::post(format!("{api_url}/embeddings"))
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
             .timeout(Duration::from_secs(request_timeout))
             .header("Content-Type", "application/json")
@@ -246,6 +253,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
+        let api_url = self.api_url.as_str();
         let api_key = self.get_api_key()?;
 
         let mut request_number = 0;
@@ -255,6 +263,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
+                    &api_url,
                     &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,

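The retry loop around `send_request` (partially visible above) sleeps for
the next entry in `BACKOFF_SECONDS` after each failed attempt. A
self-contained sketch of that pattern using the same constants; the
`send` closure stands in for the real HTTP call:

```rust
/// Sketch of the retry/backoff pattern from the embedding provider.
/// `send` stands in for `send_request`; the sleep is synchronous here
/// for brevity, whereas the real code awaits on a background executor.
fn send_with_backoff<T, E>(mut send: impl FnMut() -> Result<T, E>) -> Result<T, E> {
    const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];
    const MAX_RETRIES: usize = 4;

    let mut attempt = 0;
    loop {
        match send() {
            Ok(response) => return Ok(response),
            Err(err) if attempt + 1 >= MAX_RETRIES => return Err(err),
            Err(_) => {
                // Back off progressively longer before the next attempt.
                std::thread::sleep(std::time::Duration::from_secs(BACKOFF_SECONDS[attempt]));
                attempt += 1;
            }
        }
    }
}
```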
crates/assistant/src/assistant.rs

@@ -68,6 +68,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
+    api_url: Option<String>,
     model: OpenAiModel,
 }
 

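Because the new field is an `Option`, conversations saved before this
change still load: serde deserializes a missing `Option` field as `None`.
A reduced sketch (the struct is trimmed to the relevant fields, assuming
the same `Deserialize` derive the real type needs to load from disk):

```rust
use serde::Deserialize;

// Trimmed-down stand-in for `SavedConversation`.
#[derive(Deserialize)]
struct SavedConversation {
    summary: String,
    api_url: Option<String>, // absent in conversations saved before this PR
}

fn main() {
    // An old saved conversation has no "api_url" key, yet it still
    // deserializes, with `api_url` defaulting to `None`.
    let old = r#"{ "summary": "old conversation" }"#;
    let parsed: SavedConversation = serde_json::from_str(old).unwrap();
    assert_eq!(parsed.summary, "old conversation");
    assert!(parsed.api_url.is_none());
}
```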
crates/assistant/src/assistant_panel.rs

@@ -7,6 +7,7 @@ use crate::{
     SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
 };
 use ai::prompts::repository_context::PromptCodeSnippet;
+use ai::providers::open_ai::OPEN_AI_API_URL;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
@@ -121,10 +122,22 @@ impl AssistantPanel {
                 .await
                 .log_err()
                 .unwrap_or_default();
-            // Defaulting currently to GPT4, allow for this to be set via config.
-            let completion_provider =
-                OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
-                    .await;
+            let (api_url, model_name) = cx
+                .update(|cx| {
+                    let settings = AssistantSettings::get_global(cx);
+                    (
+                        settings.openai_api_url.clone(),
+                        settings.default_open_ai_model.full_name().to_string(),
+                    )
+                })
+                .log_err()
+                .unwrap();
+            let completion_provider = OpenAiCompletionProvider::new(
+                api_url,
+                model_name,
+                cx.background_executor().clone(),
+            )
+            .await;
 
             // TODO: deserialize state.
             let workspace_handle = workspace.clone();
@@ -1407,6 +1420,7 @@ struct Conversation {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     model: OpenAiModel,
+    api_url: Option<String>,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -1441,6 +1455,7 @@ impl Conversation {
 
         let settings = AssistantSettings::get_global(cx);
         let model = settings.default_open_ai_model.clone();
+        let api_url = settings.openai_api_url.clone();
 
         let mut this = Self {
             id: Some(Uuid::new_v4().to_string()),
@@ -1454,6 +1469,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url: Some(api_url),
             model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
@@ -1499,6 +1515,7 @@ impl Conversation {
                 .map(|summary| summary.text.clone())
                 .unwrap_or_default(),
             model: self.model.clone(),
+            api_url: self.api_url.clone(),
         }
     }
 
@@ -1513,8 +1530,12 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let api_url = saved_conversation.api_url;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
             OpenAiCompletionProvider::new(
+                api_url
+                    .clone()
+                    .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
@@ -1567,6 +1588,7 @@ impl Conversation {
                 token_count: None,
                 max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
                 pending_token_count: Task::ready(None),
+                api_url,
                 model,
                 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
                 pending_save: Task::ready(Ok(())),

crates/assistant/src/assistant_settings.rs

@@ -55,6 +55,7 @@ pub struct AssistantSettings {
     pub default_width: Pixels,
     pub default_height: Pixels,
     pub default_open_ai_model: OpenAiModel,
+    pub openai_api_url: String,
 }
 
 /// Assistant panel settings
@@ -80,6 +81,10 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: gpt-4-1106-preview
     pub default_open_ai_model: Option<OpenAiModel>,
+    /// OpenAI API base URL to use when starting new conversations.
+    ///
+    /// Default: https://api.openai.com/v1
+    pub openai_api_url: Option<String>,
 }
 
 impl Settings for AssistantSettings {

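The `Settings::load` merge itself is outside this hunk; conceptually, an
unset user value falls back to the default shipped in `default.json`. A
one-line sketch of that resolution (the helper is illustrative, not from
the PR):

```rust
/// Illustrative only: an unset `openai_api_url` in user settings
/// resolves to the default from assets/settings/default.json.
fn resolve_openai_api_url(user_setting: Option<String>) -> String {
    user_setting.unwrap_or_else(|| "https://api.openai.com/v1".to_string())
}
```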
crates/semantic_index/src/semantic_index.rs

@@ -8,7 +8,7 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::OpenAiEmbeddingProvider;
+use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -91,8 +91,13 @@ pub fn init(
     .detach();
 
     cx.spawn(move |cx| async move {
-        let embedding_provider =
-            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
+        let embedding_provider = OpenAiEmbeddingProvider::new(
+            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
+            OPEN_AI_API_URL.to_string(),
+            http_client,
+            cx.background_executor().clone(),
+        )
+        .await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,