assistant2.rs

  1mod assistant_settings;
  2mod completion_provider;
  3mod tools;
  4mod ui;
  5
  6use ::ui::{div, prelude::*, Color, ViewContext};
  7use anyhow::{Context, Result};
  8use assistant_tooling::{ToolFunctionCall, ToolRegistry};
  9use client::{proto, Client, UserStore};
 10use collections::HashMap;
 11use completion_provider::*;
 12use editor::Editor;
 13use feature_flags::FeatureFlagAppExt as _;
 14use futures::{future::join_all, StreamExt};
 15use gpui::{
 16    list, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, FocusableView,
 17    ListAlignment, ListState, Model, Render, Task, View, WeakView,
 18};
 19use language::{language_settings::SoftWrap, LanguageRegistry};
 20use open_ai::{FunctionContent, ToolCall, ToolCallContent};
 21use rich_text::RichText;
 22use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
 23use serde::Deserialize;
 24use settings::Settings;
 25use std::sync::Arc;
 26use ui::Composer;
 27use util::{paths::EMBEDDINGS_DIR, ResultExt};
 28use workspace::{
 29    dock::{DockPosition, Panel, PanelEvent},
 30    Workspace,
 31};
 32
 33pub use assistant_settings::AssistantSettings;
 34
 35use crate::tools::{CreateBufferTool, ProjectIndexTool};
 36use crate::ui::UserOrAssistant;
 37
 38const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
 39
 40#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 41pub struct Submit(SubmitMode);
 42
 43/// There are multiple different ways to submit a model request, represented by this enum.
 44#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 45pub enum SubmitMode {
 46    /// Only include the conversation.
 47    Simple,
 48    /// Send the current file as context.
 49    CurrentFile,
 50    /// Search the codebase and send relevant excerpts.
 51    Codebase,
 52}
 53
 54gpui::actions!(assistant2, [Cancel, ToggleFocus, DebugProjectIndex]);
 55gpui::impl_actions!(assistant2, [Submit]);
 56
 57pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 58    AssistantSettings::register(cx);
 59
 60    cx.spawn(|mut cx| {
 61        let client = client.clone();
 62        async move {
 63            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
 64            let semantic_index = SemanticIndex::new(
 65                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
 66                Arc::new(embedding_provider),
 67                &mut cx,
 68            )
 69            .await?;
 70            cx.update(|cx| cx.set_global(semantic_index))
 71        }
 72    })
 73    .detach();
 74
 75    cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
 76        client,
 77    )));
 78
 79    cx.observe_new_views(
 80        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 81            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 82                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 83            });
 84        },
 85    )
 86    .detach();
 87}
 88
 89pub fn enabled(cx: &AppContext) -> bool {
 90    cx.is_staff()
 91}
 92
 93pub struct AssistantPanel {
 94    chat: View<AssistantChat>,
 95    width: Option<Pixels>,
 96}
 97
 98impl AssistantPanel {
 99    pub fn load(
100        workspace: WeakView<Workspace>,
101        cx: AsyncWindowContext,
102    ) -> Task<Result<View<Self>>> {
103        cx.spawn(|mut cx| async move {
104            let (app_state, project) = workspace.update(&mut cx, |workspace, _| {
105                (workspace.app_state().clone(), workspace.project().clone())
106            })?;
107
108            let user_store = app_state.user_store.clone();
109
110            cx.new_view(|cx| {
111                // todo!("this will panic if the semantic index failed to load or has not loaded yet")
112                let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
113                    semantic_index.project_index(project.clone(), cx)
114                });
115
116                let mut tool_registry = ToolRegistry::new();
117                tool_registry
118                    .register(
119                        ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
120                        cx,
121                    )
122                    .context("failed to register ProjectIndexTool")
123                    .log_err();
124                tool_registry
125                    .register(
126                        CreateBufferTool::new(workspace.clone(), project.clone()),
127                        cx,
128                    )
129                    .context("failed to register CreateBufferTool")
130                    .log_err();
131
132                let tool_registry = Arc::new(tool_registry);
133
134                Self::new(
135                    app_state.languages.clone(),
136                    tool_registry,
137                    user_store,
138                    Some(project_index),
139                    cx,
140                )
141            })
142        })
143    }
144
145    pub fn new(
146        language_registry: Arc<LanguageRegistry>,
147        tool_registry: Arc<ToolRegistry>,
148        user_store: Model<UserStore>,
149        project_index: Option<Model<ProjectIndex>>,
150        cx: &mut ViewContext<Self>,
151    ) -> Self {
152        let chat = cx.new_view(|cx| {
153            AssistantChat::new(
154                language_registry.clone(),
155                tool_registry.clone(),
156                user_store,
157                project_index,
158                cx,
159            )
160        });
161
162        Self { width: None, chat }
163    }
164}
165
166impl Render for AssistantPanel {
167    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
168        div()
169            .size_full()
170            .v_flex()
171            .p_2()
172            .bg(cx.theme().colors().background)
173            .child(self.chat.clone())
174    }
175}
176
177impl Panel for AssistantPanel {
178    fn persistent_name() -> &'static str {
179        "AssistantPanelv2"
180    }
181
182    fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition {
183        // todo!("Add a setting / use assistant settings")
184        DockPosition::Right
185    }
186
187    fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool {
188        matches!(position, DockPosition::Right)
189    }
190
191    fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext<Self>) {
192        // Do nothing until we have a setting for this
193    }
194
195    fn size(&self, _cx: &WindowContext) -> Pixels {
196        self.width.unwrap_or(px(400.))
197    }
198
199    fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
200        self.width = size;
201        cx.notify();
202    }
203
204    fn icon(&self, _cx: &WindowContext) -> Option<::ui::IconName> {
205        Some(IconName::Ai)
206    }
207
208    fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> {
209        Some("Assistant Panel ✨")
210    }
211
212    fn toggle_action(&self) -> Box<dyn gpui::Action> {
213        Box::new(ToggleFocus)
214    }
215}
216
217impl EventEmitter<PanelEvent> for AssistantPanel {}
218
219impl FocusableView for AssistantPanel {
220    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
221        self.chat.read(cx).composer_editor.read(cx).focus_handle(cx)
222    }
223}
224
225struct AssistantChat {
226    model: String,
227    messages: Vec<ChatMessage>,
228    list_state: ListState,
229    language_registry: Arc<LanguageRegistry>,
230    composer_editor: View<Editor>,
231    user_store: Model<UserStore>,
232    next_message_id: MessageId,
233    collapsed_messages: HashMap<MessageId, bool>,
234    pending_completion: Option<Task<()>>,
235    tool_registry: Arc<ToolRegistry>,
236    project_index: Option<Model<ProjectIndex>>,
237}
238
239impl AssistantChat {
240    fn new(
241        language_registry: Arc<LanguageRegistry>,
242        tool_registry: Arc<ToolRegistry>,
243        user_store: Model<UserStore>,
244        project_index: Option<Model<ProjectIndex>>,
245        cx: &mut ViewContext<Self>,
246    ) -> Self {
247        let model = CompletionProvider::get(cx).default_model();
248        let view = cx.view().downgrade();
249        let list_state = ListState::new(
250            0,
251            ListAlignment::Bottom,
252            px(1024.),
253            move |ix, cx: &mut WindowContext| {
254                view.update(cx, |this, cx| this.render_message(ix, cx))
255                    .unwrap()
256            },
257        );
258
259        Self {
260            model,
261            messages: Vec::new(),
262            composer_editor: cx.new_view(|cx| {
263                let mut editor = Editor::auto_height(80, cx);
264                editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
265                editor.set_placeholder_text("Type a message to the assistant", cx);
266                editor
267            }),
268            list_state,
269            user_store,
270            language_registry,
271            project_index,
272            next_message_id: MessageId(0),
273            collapsed_messages: HashMap::default(),
274            pending_completion: None,
275            tool_registry,
276        }
277    }
278
279    fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
280        self.messages.iter().find_map(|message| match message {
281            ChatMessage::User(message) => message
282                .body
283                .focus_handle(cx)
284                .contains_focused(cx)
285                .then_some(message.id),
286            ChatMessage::Assistant(_) => None,
287        })
288    }
289
290    fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
291        if self.pending_completion.take().is_none() {
292            cx.propagate();
293            return;
294        }
295
296        if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
297            if message.body.text.is_empty() {
298                self.pop_message(cx);
299            }
300        }
301    }
302
303    fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
304        // Don't allow multiple concurrent completions.
305        if self.pending_completion.is_some() {
306            cx.propagate();
307            return;
308        }
309
310        if let Some(focused_message_id) = self.focused_message_id(cx) {
311            self.truncate_messages(focused_message_id, cx);
312        } else if self.composer_editor.focus_handle(cx).is_focused(cx) {
313            let message = self.composer_editor.update(cx, |composer_editor, cx| {
314                let text = composer_editor.text(cx);
315                let id = self.next_message_id.post_inc();
316                let body = cx.new_view(|cx| {
317                    let mut editor = Editor::auto_height(80, cx);
318                    editor.set_text(text, cx);
319                    editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
320                    editor
321                });
322                composer_editor.clear(cx);
323                ChatMessage::User(UserMessage { id, body })
324            });
325            self.push_message(message, cx);
326        } else {
327            log::error!("unexpected state: no user message editor is focused.");
328            return;
329        }
330
331        let mode = *mode;
332        self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
333            Self::request_completion(
334                this.clone(),
335                mode,
336                MAX_COMPLETION_CALLS_PER_SUBMISSION,
337                &mut cx,
338            )
339            .await
340            .log_err();
341
342            this.update(&mut cx, |this, cx| {
343                let composer_focus_handle = this.composer_editor.focus_handle(cx);
344                cx.focus(&composer_focus_handle);
345                this.pending_completion = None;
346            })
347            .context("Failed to push new user message")
348            .log_err();
349        }));
350    }
351
352    fn can_submit(&self) -> bool {
353        self.pending_completion.is_none()
354    }
355
356    fn debug_project_index(&mut self, _: &DebugProjectIndex, cx: &mut ViewContext<Self>) {
357        if let Some(index) = &self.project_index {
358            index.update(cx, |project_index, cx| {
359                project_index.debug(cx).detach_and_log_err(cx)
360            });
361        }
362    }
363
364    async fn request_completion(
365        this: WeakView<Self>,
366        mode: SubmitMode,
367        limit: usize,
368        cx: &mut AsyncWindowContext,
369    ) -> Result<()> {
370        let mut call_count = 0;
371        loop {
372            let complete = async {
373                let completion = this.update(cx, |this, cx| {
374                    this.push_new_assistant_message(cx);
375
376                    let definitions = if call_count < limit
377                        && matches!(mode, SubmitMode::Codebase | SubmitMode::Simple)
378                    {
379                        this.tool_registry.definitions()
380                    } else {
381                        &[]
382                    };
383                    call_count += 1;
384
385                    let messages = this.completion_messages(cx);
386
387                    CompletionProvider::get(cx).complete(
388                        this.model.clone(),
389                        messages,
390                        Vec::new(),
391                        1.0,
392                        definitions,
393                    )
394                });
395
396                let mut stream = completion?.await?;
397                let mut body = String::new();
398                while let Some(delta) = stream.next().await {
399                    let delta = delta?;
400                    this.update(cx, |this, cx| {
401                        if let Some(ChatMessage::Assistant(AssistantMessage {
402                            body: message_body,
403                            tool_calls: message_tool_calls,
404                            ..
405                        })) = this.messages.last_mut()
406                        {
407                            if let Some(content) = &delta.content {
408                                body.push_str(content);
409                            }
410
411                            for tool_call in delta.tool_calls {
412                                let index = tool_call.index as usize;
413                                if index >= message_tool_calls.len() {
414                                    message_tool_calls.resize_with(index + 1, Default::default);
415                                }
416                                let call = &mut message_tool_calls[index];
417
418                                if let Some(id) = &tool_call.id {
419                                    call.id.push_str(id);
420                                }
421
422                                match tool_call.variant {
423                                    Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
424                                        if let Some(name) = &tool_call.name {
425                                            call.name.push_str(name);
426                                        }
427                                        if let Some(arguments) = &tool_call.arguments {
428                                            call.arguments.push_str(arguments);
429                                        }
430                                    }
431                                    None => {}
432                                }
433                            }
434
435                            *message_body =
436                                RichText::new(body.clone(), &[], &this.language_registry);
437                            cx.notify();
438                        } else {
439                            unreachable!()
440                        }
441                    })?;
442                }
443
444                anyhow::Ok(())
445            }
446            .await;
447
448            let mut tool_tasks = Vec::new();
449            this.update(cx, |this, cx| {
450                if let Some(ChatMessage::Assistant(AssistantMessage {
451                    error: message_error,
452                    tool_calls,
453                    ..
454                })) = this.messages.last_mut()
455                {
456                    if let Err(error) = complete {
457                        message_error.replace(SharedString::from(error.to_string()));
458                        cx.notify();
459                    } else {
460                        for tool_call in tool_calls.iter() {
461                            tool_tasks.push(this.tool_registry.call(tool_call, cx));
462                        }
463                    }
464                }
465            })?;
466
467            if tool_tasks.is_empty() {
468                return Ok(());
469            }
470
471            let tools = join_all(tool_tasks.into_iter()).await;
472            // If the WindowContext went away for any tool's view we don't include it
473            // especially since the below call would fail for the same reason.
474            let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
475
476            this.update(cx, |this, cx| {
477                if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
478                    this.messages.last_mut()
479                {
480                    *tool_calls = tools;
481                    cx.notify();
482                }
483            })?;
484        }
485    }
486
487    fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
488        let message = ChatMessage::Assistant(AssistantMessage {
489            id: self.next_message_id.post_inc(),
490            body: RichText::default(),
491            tool_calls: Vec::new(),
492            error: None,
493        });
494        self.push_message(message, cx);
495    }
496
497    fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext<Self>) {
498        let old_len = self.messages.len();
499        let focus_handle = Some(message.focus_handle(cx));
500        self.messages.push(message);
501        self.list_state
502            .splice_focusable(old_len..old_len, focus_handle);
503        cx.notify();
504    }
505
506    fn pop_message(&mut self, cx: &mut ViewContext<Self>) {
507        if self.messages.is_empty() {
508            return;
509        }
510
511        self.messages.pop();
512        self.list_state
513            .splice(self.messages.len()..self.messages.len() + 1, 0);
514        cx.notify();
515    }
516
517    fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext<Self>) {
518        if let Some(index) = self.messages.iter().position(|message| match message {
519            ChatMessage::User(message) => message.id == last_message_id,
520            ChatMessage::Assistant(message) => message.id == last_message_id,
521        }) {
522            self.list_state.splice(index + 1..self.messages.len(), 0);
523            self.messages.truncate(index + 1);
524            cx.notify();
525        }
526    }
527
528    fn is_message_collapsed(&self, id: &MessageId) -> bool {
529        self.collapsed_messages.get(id).copied().unwrap_or_default()
530    }
531
532    fn toggle_message_collapsed(&mut self, id: MessageId) {
533        let entry = self.collapsed_messages.entry(id).or_insert(false);
534        *entry = !*entry;
535    }
536
537    fn render_error(
538        &self,
539        error: Option<SharedString>,
540        _ix: usize,
541        cx: &mut ViewContext<Self>,
542    ) -> AnyElement {
543        let theme = cx.theme();
544
545        if let Some(error) = error {
546            div()
547                .py_1()
548                .px_2()
549                .neg_mx_1()
550                .rounded_md()
551                .border()
552                .border_color(theme.status().error_border)
553                // .bg(theme.status().error_background)
554                .text_color(theme.status().error)
555                .child(error.clone())
556                .into_any_element()
557        } else {
558            div().into_any_element()
559        }
560    }
561
562    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
563        let is_last = ix == self.messages.len() - 1;
564
565        match &self.messages[ix] {
566            ChatMessage::User(UserMessage { id, body }) => div()
567                .when(!is_last, |element| element.mb_2())
568                .child(crate::ui::ChatMessage::new(
569                    *id,
570                    UserOrAssistant::User(self.user_store.read(cx).current_user()),
571                    Some(body.clone().into_any_element()),
572                    self.is_message_collapsed(id),
573                    Box::new(cx.listener({
574                        let id = *id;
575                        move |assistant_chat, _event, _cx| {
576                            assistant_chat.toggle_message_collapsed(id)
577                        }
578                    })),
579                ))
580                .into_any(),
581            ChatMessage::Assistant(AssistantMessage {
582                id,
583                body,
584                error,
585                tool_calls,
586                ..
587            }) => {
588                let assistant_body = if body.text.is_empty() {
589                    None
590                } else {
591                    Some(
592                        div()
593                            .p_2()
594                            .child(body.element(ElementId::from(id.0), cx))
595                            .into_any_element(),
596                    )
597                };
598
599                div()
600                    .when(!is_last, |element| element.mb_2())
601                    .child(crate::ui::ChatMessage::new(
602                        *id,
603                        UserOrAssistant::Assistant,
604                        assistant_body,
605                        self.is_message_collapsed(id),
606                        Box::new(cx.listener({
607                            let id = *id;
608                            move |assistant_chat, _event, _cx| {
609                                assistant_chat.toggle_message_collapsed(id)
610                            }
611                        })),
612                    ))
613                    // TODO: Should the errors and tool calls get passed into `ChatMessage`?
614                    .child(self.render_error(error.clone(), ix, cx))
615                    .children(tool_calls.iter().map(|tool_call| {
616                        let result = &tool_call.result;
617                        let name = tool_call.name.clone();
618                        match result {
619                            Some(result) => {
620                                div().p_2().child(result.into_any_element(&name)).into_any()
621                            }
622                            None => div()
623                                .p_2()
624                                .child(Label::new(name).color(Color::Modified))
625                                .child("Running...")
626                                .into_any(),
627                        }
628                    }))
629                    .into_any()
630            }
631        }
632    }
633
634    fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
635        let mut completion_messages = Vec::new();
636
637        for message in &self.messages {
638            match message {
639                ChatMessage::User(UserMessage { body, .. }) => {
640                    // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
641                    // contexts.iter().for_each(|context| {
642                    //     completion_messages.extend(context.completion_messages(cx))
643                    // });
644
645                    // Show user's message last so that the assistant is grounded in the user's request
646                    completion_messages.push(CompletionMessage::User {
647                        content: body.read(cx).text(cx),
648                    });
649                }
650                ChatMessage::Assistant(AssistantMessage {
651                    body, tool_calls, ..
652                }) => {
653                    // In no case do we want to send an empty message. This shouldn't happen, but we might as well
654                    // not break the Chat API if it does.
655                    if body.text.is_empty() && tool_calls.is_empty() {
656                        continue;
657                    }
658
659                    let tool_calls_from_assistant = tool_calls
660                        .iter()
661                        .map(|tool_call| ToolCall {
662                            content: ToolCallContent::Function {
663                                function: FunctionContent {
664                                    name: tool_call.name.clone(),
665                                    arguments: tool_call.arguments.clone(),
666                                },
667                            },
668                            id: tool_call.id.clone(),
669                        })
670                        .collect();
671
672                    completion_messages.push(CompletionMessage::Assistant {
673                        content: Some(body.text.to_string()),
674                        tool_calls: tool_calls_from_assistant,
675                    });
676
677                    for tool_call in tool_calls {
678                        // todo!(): we should not be sending when the tool is still running / has no result
679                        // For now I'm going to have to assume we send an empty string because otherwise
680                        // the Chat API will break -- there is a required message for every tool call by ID
681                        let content = match &tool_call.result {
682                            Some(result) => result.format(&tool_call.name),
683                            None => "".to_string(),
684                        };
685
686                        completion_messages.push(CompletionMessage::Tool {
687                            content,
688                            tool_call_id: tool_call.id.clone(),
689                        });
690                    }
691                }
692            }
693        }
694
695        completion_messages
696    }
697}
698
699impl Render for AssistantChat {
700    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
701        div()
702            .relative()
703            .flex_1()
704            .v_flex()
705            .key_context("AssistantChat")
706            .on_action(cx.listener(Self::submit))
707            .on_action(cx.listener(Self::cancel))
708            .on_action(cx.listener(Self::debug_project_index))
709            .text_color(Color::Default.color(cx))
710            .child(list(self.list_state.clone()).flex_1())
711            .child(Composer::new(
712                cx.view().downgrade(),
713                self.model.clone(),
714                self.composer_editor.clone(),
715                self.user_store.read(cx).current_user(),
716                self.can_submit(),
717                self.tool_registry.clone(),
718            ))
719    }
720}
721
/// Identifier assigned to each chat message, allocated in increasing order via `post_inc`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct MessageId(usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one (like C's `i++`).
    fn post_inc(&mut self) -> Self {
        let current = self.0;
        self.0 = current + 1;
        MessageId(current)
    }
}
732
733enum ChatMessage {
734    User(UserMessage),
735    Assistant(AssistantMessage),
736}
737
738impl ChatMessage {
739    fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
740        match self {
741            ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
742            ChatMessage::Assistant(_) => None,
743        }
744    }
745}
746
747struct UserMessage {
748    id: MessageId,
749    body: View<Editor>,
750}
751
752struct AssistantMessage {
753    id: MessageId,
754    body: RichText,
755    tool_calls: Vec<ToolFunctionCall>,
756    error: Option<SharedString>,
757}