@@ -31,9 +31,9 @@ use fs::Fs;
use futures::StreamExt;
use gpui::{
canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
- AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter, FocusHandle,
- FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
- ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
+ AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter,
+ FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement,
+ IntoElement, Model, ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
};
@@ -123,6 +123,10 @@ impl AssistantPanel {
.await
.log_err()
.unwrap_or_default();
+ // Defaulting currently to GPT4, allow for this to be set via config.
+ let completion_provider =
+ OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
+ .await;
// TODO: deserialize state.
let workspace_handle = workspace.clone();
@@ -156,11 +160,6 @@ impl AssistantPanel {
});
let semantic_index = SemanticIndex::global(cx);
- // Defaulting currently to GPT4, allow for this to be set via config.
- let completion_provider = Arc::new(OpenAICompletionProvider::new(
- "gpt-4",
- cx.background_executor().clone(),
- ));
let focus_handle = cx.focus_handle();
cx.on_focus_in(&focus_handle, Self::focus_in).detach();
@@ -176,7 +175,7 @@ impl AssistantPanel {
zoomed: false,
focus_handle,
toolbar,
- completion_provider,
+ completion_provider: Arc::new(completion_provider),
api_key_editor: None,
languages: workspace.app_state().languages.clone(),
fs: workspace.app_state().fs.clone(),
@@ -1079,9 +1078,9 @@ impl AssistantPanel {
cx.spawn(|this, mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?;
- let conversation = cx.new_model(|cx| {
- Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
- })?;
+ let conversation =
+ Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx)
+ .await?;
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
// the same conversation, we don't want to open it again.
@@ -1462,21 +1461,25 @@ impl Conversation {
}
}
- fn deserialize(
+ async fn deserialize(
saved_conversation: SavedConversation,
path: PathBuf,
language_registry: Arc<LanguageRegistry>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
+ cx: &mut AsyncAppContext,
+ ) -> Result<Model<Self>> {
let id = match saved_conversation.id {
Some(id) => Some(id),
None => Some(Uuid::new_v4().to_string()),
};
let model = saved_conversation.model;
let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
- OpenAICompletionProvider::new(model.full_name(), cx.background_executor().clone()),
+ OpenAICompletionProvider::new(
+ model.full_name().into(),
+ cx.background_executor().clone(),
+ )
+ .await,
);
- completion_provider.retrieve_credentials(cx);
+ cx.update(|cx| completion_provider.retrieve_credentials(cx))?;
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
@@ -1499,32 +1502,34 @@ impl Conversation {
})
.detach_and_log_err(cx);
buffer
- });
-
- let mut this = Self {
- id,
- message_anchors,
- messages_metadata: saved_conversation.message_metadata,
- next_message_id,
- summary: Some(Summary {
- text: saved_conversation.summary,
- done: true,
- }),
- pending_summary: Task::ready(None),
- completion_count: Default::default(),
- pending_completions: Default::default(),
- token_count: None,
- max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
- pending_token_count: Task::ready(None),
- model,
- _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
- pending_save: Task::ready(Ok(())),
- path: Some(path),
- buffer,
- completion_provider,
- };
- this.count_remaining_tokens(cx);
- this
+ })?;
+
+ cx.new_model(|cx| {
+ let mut this = Self {
+ id,
+ message_anchors,
+ messages_metadata: saved_conversation.message_metadata,
+ next_message_id,
+ summary: Some(Summary {
+ text: saved_conversation.summary,
+ done: true,
+ }),
+ pending_summary: Task::ready(None),
+ completion_count: Default::default(),
+ pending_completions: Default::default(),
+ token_count: None,
+ max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
+ pending_token_count: Task::ready(None),
+ model,
+ _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+ pending_save: Task::ready(Ok(())),
+ path: Some(path),
+ buffer,
+ completion_provider,
+ };
+ this.count_remaining_tokens(cx);
+ this
+ })
}
fn handle_buffer_event(
@@ -3169,7 +3174,7 @@ mod tests {
use super::*;
use crate::MessageId;
use ai::test::FakeCompletionProvider;
- use gpui::AppContext;
+ use gpui::{AppContext, TestAppContext};
use settings::SettingsStore;
#[gpui::test]
@@ -3487,16 +3492,17 @@ mod tests {
}
#[gpui::test]
- fn test_serialization(cx: &mut AppContext) {
- let settings_store = SettingsStore::test(cx);
+ async fn test_serialization(cx: &mut TestAppContext) {
+ let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
- init(cx);
+ cx.update(init);
let registry = Arc::new(LanguageRegistry::test());
let completion_provider = Arc::new(FakeCompletionProvider::new());
let conversation =
cx.new_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
- let buffer = conversation.read(cx).buffer.clone();
- let message_0 = conversation.read(cx).message_anchors[0].id;
+ let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
+ let message_0 =
+ conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
let message_1 = conversation.update(cx, |conversation, cx| {
conversation
.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
@@ -3517,9 +3523,9 @@ mod tests {
.unwrap()
});
buffer.update(cx, |buffer, cx| buffer.undo(cx));
- assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
+ assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
assert_eq!(
- messages(&conversation, cx),
+ cx.read(|cx| messages(&conversation, cx)),
[
(message_0, Role::User, 0..2),
(message_1.id, Role::Assistant, 2..6),
@@ -3527,18 +3533,22 @@ mod tests {
]
);
- let deserialized_conversation = cx.new_model(|cx| {
- Conversation::deserialize(
- conversation.read(cx).serialize(cx),
- Default::default(),
- registry.clone(),
- cx,
- )
- });
- let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
- assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
+ let deserialized_conversation = Conversation::deserialize(
+ conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
+ Default::default(),
+ registry.clone(),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+ let deserialized_buffer =
+ deserialized_conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
+ assert_eq!(
+ deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
+ "a\nb\nc\n"
+ );
assert_eq!(
- messages(&deserialized_conversation, cx),
+ cx.read(|cx| messages(&deserialized_conversation, cx)),
[
(message_0, Role::User, 0..2),
(message_1.id, Role::Assistant, 2..6),