Load language models in the background

Antonio Scandurra created

Change summary

crates/ai/src/providers/open_ai/completion.rs |   6 
crates/ai/src/providers/open_ai/embedding.rs  |   7 
crates/assistant/src/assistant_panel.rs       | 134 +++++++++++---------
crates/semantic_index/src/semantic_index.rs   |   7 
4 files changed, 84 insertions(+), 70 deletions(-)

Detailed changes

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -201,8 +201,10 @@ pub struct OpenAICompletionProvider {
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
-        let model = OpenAILanguageModel::load(model_name);
+    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+        let model = executor
+            .spawn(async move { OpenAILanguageModel::load(&model_name) })
+            .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             model,

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -67,11 +67,14 @@ struct OpenAIEmbeddingUsage {
 }
 
 impl OpenAIEmbeddingProvider {
-    pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
-        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        // Loading the model is expensive, so ensure this runs off the main thread.
+        let model = executor
+            .spawn(async move { OpenAILanguageModel::load("text-embedding-ada-002") })
+            .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAIEmbeddingProvider {

crates/assistant/src/assistant_panel.rs 🔗

@@ -31,9 +31,9 @@ use fs::Fs;
 use futures::StreamExt;
 use gpui::{
     canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
-    AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter, FocusHandle,
-    FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
-    ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
+    AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter,
+    FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement,
+    IntoElement, Model, ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
     StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
     View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
 };
@@ -123,6 +123,10 @@ impl AssistantPanel {
                 .await
                 .log_err()
                 .unwrap_or_default();
+            // Defaulting currently to GPT4, allow for this to be set via config.
+            let completion_provider =
+                OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
+                    .await;
 
             // TODO: deserialize state.
             let workspace_handle = workspace.clone();
@@ -156,11 +160,6 @@ impl AssistantPanel {
                     });
 
                     let semantic_index = SemanticIndex::global(cx);
-                    // Defaulting currently to GPT4, allow for this to be set via config.
-                    let completion_provider = Arc::new(OpenAICompletionProvider::new(
-                        "gpt-4",
-                        cx.background_executor().clone(),
-                    ));
 
                     let focus_handle = cx.focus_handle();
                     cx.on_focus_in(&focus_handle, Self::focus_in).detach();
@@ -176,7 +175,7 @@ impl AssistantPanel {
                         zoomed: false,
                         focus_handle,
                         toolbar,
-                        completion_provider,
+                        completion_provider: Arc::new(completion_provider),
                         api_key_editor: None,
                         languages: workspace.app_state().languages.clone(),
                         fs: workspace.app_state().fs.clone(),
@@ -1079,9 +1078,9 @@ impl AssistantPanel {
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = fs.load(&path).await?;
             let saved_conversation = serde_json::from_str(&saved_conversation)?;
-            let conversation = cx.new_model(|cx| {
-                Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
-            })?;
+            let conversation =
+                Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx)
+                    .await?;
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened
                 // the same conversation, we don't want to open it again.
@@ -1462,21 +1461,25 @@ impl Conversation {
         }
     }
 
-    fn deserialize(
+    async fn deserialize(
         saved_conversation: SavedConversation,
         path: PathBuf,
         language_registry: Arc<LanguageRegistry>,
-        cx: &mut ModelContext<Self>,
-    ) -> Self {
+        cx: &mut AsyncAppContext,
+    ) -> Result<Model<Self>> {
         let id = match saved_conversation.id {
             Some(id) => Some(id),
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
-            OpenAICompletionProvider::new(model.full_name(), cx.background_executor().clone()),
+            OpenAICompletionProvider::new(
+                model.full_name().into(),
+                cx.background_executor().clone(),
+            )
+            .await,
         );
-        completion_provider.retrieve_credentials(cx);
+        cx.update(|cx| completion_provider.retrieve_credentials(cx))?;
         let markdown = language_registry.language_for_name("Markdown");
         let mut message_anchors = Vec::new();
         let mut next_message_id = MessageId(0);
@@ -1499,32 +1502,34 @@ impl Conversation {
             })
             .detach_and_log_err(cx);
             buffer
-        });
-
-        let mut this = Self {
-            id,
-            message_anchors,
-            messages_metadata: saved_conversation.message_metadata,
-            next_message_id,
-            summary: Some(Summary {
-                text: saved_conversation.summary,
-                done: true,
-            }),
-            pending_summary: Task::ready(None),
-            completion_count: Default::default(),
-            pending_completions: Default::default(),
-            token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
-            pending_token_count: Task::ready(None),
-            model,
-            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
-            pending_save: Task::ready(Ok(())),
-            path: Some(path),
-            buffer,
-            completion_provider,
-        };
-        this.count_remaining_tokens(cx);
-        this
+        })?;
+
+        cx.new_model(|cx| {
+            let mut this = Self {
+                id,
+                message_anchors,
+                messages_metadata: saved_conversation.message_metadata,
+                next_message_id,
+                summary: Some(Summary {
+                    text: saved_conversation.summary,
+                    done: true,
+                }),
+                pending_summary: Task::ready(None),
+                completion_count: Default::default(),
+                pending_completions: Default::default(),
+                token_count: None,
+                max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
+                pending_token_count: Task::ready(None),
+                model,
+                _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+                pending_save: Task::ready(Ok(())),
+                path: Some(path),
+                buffer,
+                completion_provider,
+            };
+            this.count_remaining_tokens(cx);
+            this
+        })
     }
 
     fn handle_buffer_event(
@@ -3169,7 +3174,7 @@ mod tests {
     use super::*;
     use crate::MessageId;
     use ai::test::FakeCompletionProvider;
-    use gpui::AppContext;
+    use gpui::{AppContext, TestAppContext};
     use settings::SettingsStore;
 
     #[gpui::test]
@@ -3487,16 +3492,17 @@ mod tests {
     }
 
     #[gpui::test]
-    fn test_serialization(cx: &mut AppContext) {
-        let settings_store = SettingsStore::test(cx);
+    async fn test_serialization(cx: &mut TestAppContext) {
+        let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        init(cx);
+        cx.update(init);
         let registry = Arc::new(LanguageRegistry::test());
         let completion_provider = Arc::new(FakeCompletionProvider::new());
         let conversation =
             cx.new_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
-        let buffer = conversation.read(cx).buffer.clone();
-        let message_0 = conversation.read(cx).message_anchors[0].id;
+        let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
+        let message_0 =
+            conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
         let message_1 = conversation.update(cx, |conversation, cx| {
             conversation
                 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
@@ -3517,9 +3523,9 @@ mod tests {
                 .unwrap()
         });
         buffer.update(cx, |buffer, cx| buffer.undo(cx));
-        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
+        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
         assert_eq!(
-            messages(&conversation, cx),
+            cx.read(|cx| messages(&conversation, cx)),
             [
                 (message_0, Role::User, 0..2),
                 (message_1.id, Role::Assistant, 2..6),
@@ -3527,18 +3533,22 @@ mod tests {
             ]
         );
 
-        let deserialized_conversation = cx.new_model(|cx| {
-            Conversation::deserialize(
-                conversation.read(cx).serialize(cx),
-                Default::default(),
-                registry.clone(),
-                cx,
-            )
-        });
-        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
-        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
+        let deserialized_conversation = Conversation::deserialize(
+            conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
+            Default::default(),
+            registry.clone(),
+            &mut cx.to_async(),
+        )
+        .await
+        .unwrap();
+        let deserialized_buffer =
+            deserialized_conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
+        assert_eq!(
+            deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
+            "a\nb\nc\n"
+        );
         assert_eq!(
-            messages(&deserialized_conversation, cx),
+            cx.read(|cx| messages(&deserialized_conversation, cx)),
             [
                 (message_0, Role::User, 0..2),
                 (message_1.id, Role::Assistant, 2..6),

crates/semantic_index/src/semantic_index.rs 🔗

@@ -90,13 +90,12 @@ pub fn init(
     .detach();
 
     cx.spawn(move |cx| async move {
+        let embedding_provider =
+            OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddingProvider::new(
-                http_client,
-                cx.background_executor().clone(),
-            )),
+            Arc::new(embedding_provider),
             language_registry,
             cx.clone(),
         )