assistant2.rs

  1mod assistant_settings;
  2mod completion_provider;
  3pub mod tools;
  4mod ui;
  5
  6use ::ui::{div, prelude::*, Color, ViewContext};
  7use anyhow::{Context, Result};
  8use assistant_tooling::{ToolFunctionCall, ToolRegistry};
  9use client::{proto, Client, UserStore};
 10use collections::HashMap;
 11use completion_provider::*;
 12use editor::Editor;
 13use feature_flags::FeatureFlagAppExt as _;
 14use futures::{future::join_all, StreamExt};
 15use gpui::{
 16    list, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, FocusableView,
 17    ListAlignment, ListState, Model, Render, Task, View, WeakView,
 18};
 19use language::{language_settings::SoftWrap, LanguageRegistry};
 20use open_ai::{FunctionContent, ToolCall, ToolCallContent};
 21use rich_text::RichText;
 22use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 23use serde::Deserialize;
 24use settings::Settings;
 25use std::sync::Arc;
 26use tools::ProjectIndexTool;
 27use ui::Composer;
 28use util::{paths::EMBEDDINGS_DIR, ResultExt};
 29use workspace::{
 30    dock::{DockPosition, Panel, PanelEvent},
 31    Workspace,
 32};
 33
 34pub use assistant_settings::AssistantSettings;
 35
 36use crate::ui::UserOrAssistant;
 37
/// Passed as the `limit` argument of `AssistantChat::request_completion` for every
/// submission. Tool definitions are only offered to the model while fewer than this
/// many completion requests have been issued for the submission.
const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
 39
/// Action that submits the conversation to the model; carries the [`SubmitMode`]
/// that selects how the request is made.
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
pub struct Submit(SubmitMode);
 42
/// There are multiple different ways to submit a model request, represented by this enum.
///
/// Within this file the mode only affects whether tool definitions are sent along
/// with a completion request: they are offered for `Simple` and `Codebase`, and
/// omitted for `CurrentFile`.
#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
pub enum SubmitMode {
    /// Only include the conversation.
    Simple,
    /// Send the current file as context.
    CurrentFile,
    /// Search the codebase and send relevant excerpts.
    Codebase,
}
 53
// Data-less actions in the `assistant2` namespace: `Cancel` is handled by
// `AssistantChat::cancel`, `ToggleFocus` is registered on workspaces in `init`.
gpui::actions!(assistant2, [Cancel, ToggleFocus]);
// `Submit` carries a payload, so it is registered separately from the unit actions.
gpui::impl_actions!(assistant2, [Submit]);
 56
 57pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 58    AssistantSettings::register(cx);
 59
 60    cx.spawn(|mut cx| {
 61        let client = client.clone();
 62        async move {
 63            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
 64            let semantic_index = SemanticIndex::new(
 65                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
 66                Arc::new(embedding_provider),
 67                &mut cx,
 68            )
 69            .await?;
 70            cx.update(|cx| cx.set_global(semantic_index))
 71        }
 72    })
 73    .detach();
 74
 75    cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
 76        client,
 77    )));
 78
 79    cx.observe_new_views(
 80        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 81            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 82                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 83            });
 84        },
 85    )
 86    .detach();
 87}
 88
/// Returns whether the assistant should be available; this is exactly the result of
/// `cx.is_staff()` (presumably provided by `FeatureFlagAppExt` — confirm).
pub fn enabled(cx: &AppContext) -> bool {
    cx.is_staff()
}
 92
/// Dock panel that hosts a single assistant chat view.
pub struct AssistantPanel {
    // The chat view rendered as the panel's only child.
    chat: View<AssistantChat>,
    // Width requested through `Panel::set_size`; `None` falls back to 400px in `size`.
    width: Option<Pixels>,
}
 97
 98impl AssistantPanel {
 99    pub fn load(
100        workspace: WeakView<Workspace>,
101        cx: AsyncWindowContext,
102    ) -> Task<Result<View<Self>>> {
103        cx.spawn(|mut cx| async move {
104            let (app_state, project) = workspace.update(&mut cx, |workspace, _| {
105                (workspace.app_state().clone(), workspace.project().clone())
106            })?;
107
108            let user_store = app_state.user_store.clone();
109
110            cx.new_view(|cx| {
111                // todo!("this will panic if the semantic index failed to load or has not loaded yet")
112                let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
113                    semantic_index.project_index(project.clone(), cx)
114                });
115
116                let mut tool_registry = ToolRegistry::new();
117                tool_registry
118                    .register(
119                        ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
120                        cx,
121                    )
122                    .context("failed to register ProjectIndexTool")
123                    .log_err();
124
125                let tool_registry = Arc::new(tool_registry);
126
127                Self::new(app_state.languages.clone(), tool_registry, user_store, cx)
128            })
129        })
130    }
131
132    pub fn new(
133        language_registry: Arc<LanguageRegistry>,
134        tool_registry: Arc<ToolRegistry>,
135        user_store: Model<UserStore>,
136        cx: &mut ViewContext<Self>,
137    ) -> Self {
138        let chat = cx.new_view(|cx| {
139            AssistantChat::new(
140                language_registry.clone(),
141                tool_registry.clone(),
142                user_store,
143                cx,
144            )
145        });
146
147        Self { width: None, chat }
148    }
149}
150
151impl Render for AssistantPanel {
152    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
153        div()
154            .size_full()
155            .v_flex()
156            .p_2()
157            .bg(cx.theme().colors().background)
158            .child(self.chat.clone())
159    }
160}
161
162impl Panel for AssistantPanel {
163    fn persistent_name() -> &'static str {
164        "AssistantPanelv2"
165    }
166
167    fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition {
168        // todo!("Add a setting / use assistant settings")
169        DockPosition::Right
170    }
171
172    fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool {
173        matches!(position, DockPosition::Right)
174    }
175
176    fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext<Self>) {
177        // Do nothing until we have a setting for this
178    }
179
180    fn size(&self, _cx: &WindowContext) -> Pixels {
181        self.width.unwrap_or(px(400.))
182    }
183
184    fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
185        self.width = size;
186        cx.notify();
187    }
188
189    fn icon(&self, _cx: &WindowContext) -> Option<::ui::IconName> {
190        Some(IconName::Ai)
191    }
192
193    fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> {
194        Some("Assistant Panel ✨")
195    }
196
197    fn toggle_action(&self) -> Box<dyn gpui::Action> {
198        Box::new(ToggleFocus)
199    }
200}
201
// Marker impl: lets `AssistantPanel` emit `PanelEvent`s; no events are emitted in this file.
impl EventEmitter<PanelEvent> for AssistantPanel {}
203
204impl FocusableView for AssistantPanel {
205    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
206        self.chat.read(cx).composer_editor.read(cx).focus_handle(cx)
207    }
208}
209
/// View state for one assistant conversation: the message history, the composer,
/// and the completion currently in flight.
struct AssistantChat {
    // Model name sent with every completion request; initialized from the
    // completion provider's default model.
    model: String,
    // Conversation history in display order.
    messages: Vec<ChatMessage>,
    // List state for the rendered messages; spliced whenever `messages` changes.
    list_state: ListState,
    // Passed to `RichText::new` when assistant message bodies are rebuilt.
    language_registry: Arc<LanguageRegistry>,
    // Editor in which the user types a new message.
    composer_editor: View<Editor>,
    // Source of the current user shown next to user messages and in the composer.
    user_store: Model<UserStore>,
    // Id handed to the next message that gets created.
    next_message_id: MessageId,
    // Per-message collapsed flag; a missing entry means "not collapsed".
    collapsed_messages: HashMap<MessageId, bool>,
    // Task driving the in-flight submission; while `Some`, new submissions are refused.
    pending_completion: Option<Task<()>>,
    // Tools whose definitions are offered to the model and whose calls are executed.
    tool_registry: Arc<ToolRegistry>,
}
222
223impl AssistantChat {
224    fn new(
225        language_registry: Arc<LanguageRegistry>,
226        tool_registry: Arc<ToolRegistry>,
227        user_store: Model<UserStore>,
228        cx: &mut ViewContext<Self>,
229    ) -> Self {
230        let model = CompletionProvider::get(cx).default_model();
231        let view = cx.view().downgrade();
232        let list_state = ListState::new(
233            0,
234            ListAlignment::Bottom,
235            px(1024.),
236            move |ix, cx: &mut WindowContext| {
237                view.update(cx, |this, cx| this.render_message(ix, cx))
238                    .unwrap()
239            },
240        );
241
242        Self {
243            model,
244            messages: Vec::new(),
245            composer_editor: cx.new_view(|cx| {
246                let mut editor = Editor::auto_height(80, cx);
247                editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
248                editor.set_placeholder_text("Type a message to the assistant", cx);
249                editor
250            }),
251            list_state,
252            user_store,
253            language_registry,
254            next_message_id: MessageId(0),
255            collapsed_messages: HashMap::default(),
256            pending_completion: None,
257            tool_registry,
258        }
259    }
260
261    fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
262        self.messages.iter().find_map(|message| match message {
263            ChatMessage::User(message) => message
264                .body
265                .focus_handle(cx)
266                .contains_focused(cx)
267                .then_some(message.id),
268            ChatMessage::Assistant(_) => None,
269        })
270    }
271
272    fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
273        if self.pending_completion.take().is_none() {
274            cx.propagate();
275            return;
276        }
277
278        if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
279            if message.body.text.is_empty() {
280                self.pop_message(cx);
281            }
282        }
283    }
284
285    fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
286        // Don't allow multiple concurrent completions.
287        if self.pending_completion.is_some() {
288            cx.propagate();
289            return;
290        }
291
292        if let Some(focused_message_id) = self.focused_message_id(cx) {
293            self.truncate_messages(focused_message_id, cx);
294        } else if self.composer_editor.focus_handle(cx).is_focused(cx) {
295            let message = self.composer_editor.update(cx, |composer_editor, cx| {
296                let text = composer_editor.text(cx);
297                let id = self.next_message_id.post_inc();
298                let body = cx.new_view(|cx| {
299                    let mut editor = Editor::auto_height(80, cx);
300                    editor.set_text(text, cx);
301                    editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
302                    editor
303                });
304                composer_editor.clear(cx);
305                ChatMessage::User(UserMessage { id, body })
306            });
307            self.push_message(message, cx);
308        } else {
309            log::error!("unexpected state: no user message editor is focused.");
310            return;
311        }
312
313        let mode = *mode;
314        self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
315            Self::request_completion(
316                this.clone(),
317                mode,
318                MAX_COMPLETION_CALLS_PER_SUBMISSION,
319                &mut cx,
320            )
321            .await
322            .log_err();
323
324            this.update(&mut cx, |this, cx| {
325                let composer_focus_handle = this.composer_editor.focus_handle(cx);
326                cx.focus(&composer_focus_handle);
327                this.pending_completion = None;
328            })
329            .context("Failed to push new user message")
330            .log_err();
331        }));
332    }
333
    /// Returns `true` when no completion is currently pending.
    fn can_submit(&self) -> bool {
        self.pending_completion.is_none()
    }
337
    /// Drives one submission: requests completions until the model stops asking
    /// for tools.
    ///
    /// Each loop iteration:
    /// 1. pushes a fresh assistant message and requests a completion for the
    ///    current conversation;
    /// 2. streams the response into that message, accumulating body text and
    ///    tool-call fragments;
    /// 3. on failure, records the error text on the message and finishes;
    /// 4. otherwise runs every requested tool call, stores the results on the
    ///    message, and loops so the model can see them.
    ///
    /// Tool definitions are only sent while fewer than `limit` requests have been
    /// made and `mode` is `Codebase` or `Simple`; without definitions the model
    /// cannot request tools, which is what ends the loop.
    ///
    /// Returns `Ok(())` when an iteration produces no tool calls (including the
    /// failure case above). Returns `Err` only when updating the view through
    /// `this` fails outside the streaming step.
    async fn request_completion(
        this: WeakView<Self>,
        mode: SubmitMode,
        limit: usize,
        cx: &mut AsyncWindowContext,
    ) -> Result<()> {
        let mut call_count = 0;
        loop {
            // Errors raised inside this block are captured in `complete` rather than
            // propagated, so they can be shown on the assistant message below.
            let complete = async {
                let completion = this.update(cx, |this, cx| {
                    this.push_new_assistant_message(cx);

                    let definitions = if call_count < limit
                        && matches!(mode, SubmitMode::Codebase | SubmitMode::Simple)
                    {
                        this.tool_registry.definitions()
                    } else {
                        &[]
                    };
                    call_count += 1;

                    let messages = this.completion_messages(cx);

                    CompletionProvider::get(cx).complete(
                        this.model.clone(),
                        messages,
                        Vec::new(),
                        1.0,
                        definitions,
                    )
                });

                let mut stream = completion?.await?;
                // Full text received so far; the message body is rebuilt from it on
                // every delta.
                let mut body = String::new();
                while let Some(delta) = stream.next().await {
                    let delta = delta?;
                    this.update(cx, |this, cx| {
                        if let Some(ChatMessage::Assistant(AssistantMessage {
                            body: message_body,
                            tool_calls: message_tool_calls,
                            ..
                        })) = this.messages.last_mut()
                        {
                            if let Some(content) = &delta.content {
                                body.push_str(content);
                            }

                            for tool_call in delta.tool_calls {
                                // Deltas address tool calls by index; grow the list
                                // so that slot exists, then append the fragments.
                                let index = tool_call.index as usize;
                                if index >= message_tool_calls.len() {
                                    message_tool_calls.resize_with(index + 1, Default::default);
                                }
                                let call = &mut message_tool_calls[index];

                                if let Some(id) = &tool_call.id {
                                    call.id.push_str(id);
                                }

                                match tool_call.variant {
                                    Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
                                        if let Some(name) = &tool_call.name {
                                            call.name.push_str(name);
                                        }
                                        if let Some(arguments) = &tool_call.arguments {
                                            call.arguments.push_str(arguments);
                                        }
                                    }
                                    None => {}
                                }
                            }

                            *message_body =
                                RichText::new(body.clone(), &[], &this.language_registry);
                            cx.notify();
                        } else {
                            // An assistant message was pushed at the start of this
                            // iteration, so the last message is expected to be it.
                            unreachable!()
                        }
                    })?;
                }

                anyhow::Ok(())
            }
            .await;

            let mut tool_tasks = Vec::new();
            this.update(cx, |this, cx| {
                if let Some(ChatMessage::Assistant(AssistantMessage {
                    error: message_error,
                    tool_calls,
                    ..
                })) = this.messages.last_mut()
                {
                    if let Err(error) = complete {
                        // Show the failure on the message; no tools are started, so
                        // the function returns just below.
                        message_error.replace(SharedString::from(error.to_string()));
                        cx.notify();
                    } else {
                        for tool_call in tool_calls.iter() {
                            tool_tasks.push(this.tool_registry.call(tool_call, cx));
                        }
                    }
                }
            })?;

            // No tool calls were requested (or the completion failed): done.
            if tool_tasks.is_empty() {
                return Ok(());
            }

            let tools = join_all(tool_tasks.into_iter()).await;
            // If the WindowContext went away for any tool's view we don't include it
            // especially since the below call would fail for the same reason.
            let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();

            // Replace the streamed tool calls with the finished ones before the
            // next iteration.
            this.update(cx, |this, cx| {
                if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
                    this.messages.last_mut()
                {
                    *tool_calls = tools;
                    cx.notify();
                }
            })?;
        }
    }
460
461    fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
462        let message = ChatMessage::Assistant(AssistantMessage {
463            id: self.next_message_id.post_inc(),
464            body: RichText::default(),
465            tool_calls: Vec::new(),
466            error: None,
467        });
468        self.push_message(message, cx);
469    }
470
471    fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext<Self>) {
472        let old_len = self.messages.len();
473        let focus_handle = Some(message.focus_handle(cx));
474        self.messages.push(message);
475        self.list_state
476            .splice_focusable(old_len..old_len, focus_handle);
477        cx.notify();
478    }
479
480    fn pop_message(&mut self, cx: &mut ViewContext<Self>) {
481        if self.messages.is_empty() {
482            return;
483        }
484
485        self.messages.pop();
486        self.list_state
487            .splice(self.messages.len()..self.messages.len() + 1, 0);
488        cx.notify();
489    }
490
491    fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext<Self>) {
492        if let Some(index) = self.messages.iter().position(|message| match message {
493            ChatMessage::User(message) => message.id == last_message_id,
494            ChatMessage::Assistant(message) => message.id == last_message_id,
495        }) {
496            self.list_state.splice(index + 1..self.messages.len(), 0);
497            self.messages.truncate(index + 1);
498            cx.notify();
499        }
500    }
501
502    fn is_message_collapsed(&self, id: &MessageId) -> bool {
503        self.collapsed_messages.get(id).copied().unwrap_or_default()
504    }
505
506    fn toggle_message_collapsed(&mut self, id: MessageId) {
507        let entry = self.collapsed_messages.entry(id).or_insert(false);
508        *entry = !*entry;
509    }
510
511    fn render_error(
512        &self,
513        error: Option<SharedString>,
514        _ix: usize,
515        cx: &mut ViewContext<Self>,
516    ) -> AnyElement {
517        let theme = cx.theme();
518
519        if let Some(error) = error {
520            div()
521                .py_1()
522                .px_2()
523                .neg_mx_1()
524                .rounded_md()
525                .border()
526                .border_color(theme.status().error_border)
527                // .bg(theme.status().error_background)
528                .text_color(theme.status().error)
529                .child(error.clone())
530                .into_any_element()
531        } else {
532            div().into_any_element()
533        }
534    }
535
536    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
537        let is_last = ix == self.messages.len() - 1;
538
539        match &self.messages[ix] {
540            ChatMessage::User(UserMessage { id, body }) => div()
541                .when(!is_last, |element| element.mb_2())
542                .child(crate::ui::ChatMessage::new(
543                    *id,
544                    UserOrAssistant::User(self.user_store.read(cx).current_user()),
545                    body.clone().into_any_element(),
546                    self.is_message_collapsed(id),
547                    Box::new(cx.listener({
548                        let id = *id;
549                        move |assistant_chat, _event, _cx| {
550                            assistant_chat.toggle_message_collapsed(id)
551                        }
552                    })),
553                ))
554                .into_any(),
555            ChatMessage::Assistant(AssistantMessage {
556                id,
557                body,
558                error,
559                tool_calls,
560                ..
561            }) => {
562                let assistant_body = if body.text.is_empty() && !tool_calls.is_empty() {
563                    div()
564                } else {
565                    div().p_2().child(body.element(ElementId::from(id.0), cx))
566                };
567
568                div()
569                    .when(!is_last, |element| element.mb_2())
570                    .child(crate::ui::ChatMessage::new(
571                        *id,
572                        UserOrAssistant::Assistant,
573                        assistant_body.into_any_element(),
574                        self.is_message_collapsed(id),
575                        Box::new(cx.listener({
576                            let id = *id;
577                            move |assistant_chat, _event, _cx| {
578                                assistant_chat.toggle_message_collapsed(id)
579                            }
580                        })),
581                    ))
582                    // TODO: Should the errors and tool calls get passed into `ChatMessage`?
583                    .child(self.render_error(error.clone(), ix, cx))
584                    .children(tool_calls.iter().map(|tool_call| {
585                        let result = &tool_call.result;
586                        let name = tool_call.name.clone();
587                        match result {
588                            Some(result) => {
589                                div().p_2().child(result.into_any_element(&name)).into_any()
590                            }
591                            None => div()
592                                .p_2()
593                                .child(Label::new(name).color(Color::Modified))
594                                .child("Running...")
595                                .into_any(),
596                        }
597                    }))
598                    .into_any()
599            }
600        }
601    }
602
603    fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
604        let mut completion_messages = Vec::new();
605
606        for message in &self.messages {
607            match message {
608                ChatMessage::User(UserMessage { body, .. }) => {
609                    // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
610                    // contexts.iter().for_each(|context| {
611                    //     completion_messages.extend(context.completion_messages(cx))
612                    // });
613
614                    // Show user's message last so that the assistant is grounded in the user's request
615                    completion_messages.push(CompletionMessage::User {
616                        content: body.read(cx).text(cx),
617                    });
618                }
619                ChatMessage::Assistant(AssistantMessage {
620                    body, tool_calls, ..
621                }) => {
622                    // In no case do we want to send an empty message. This shouldn't happen, but we might as well
623                    // not break the Chat API if it does.
624                    if body.text.is_empty() && tool_calls.is_empty() {
625                        continue;
626                    }
627
628                    let tool_calls_from_assistant = tool_calls
629                        .iter()
630                        .map(|tool_call| ToolCall {
631                            content: ToolCallContent::Function {
632                                function: FunctionContent {
633                                    name: tool_call.name.clone(),
634                                    arguments: tool_call.arguments.clone(),
635                                },
636                            },
637                            id: tool_call.id.clone(),
638                        })
639                        .collect();
640
641                    completion_messages.push(CompletionMessage::Assistant {
642                        content: Some(body.text.to_string()),
643                        tool_calls: tool_calls_from_assistant,
644                    });
645
646                    for tool_call in tool_calls {
647                        // todo!(): we should not be sending when the tool is still running / has no result
648                        // For now I'm going to have to assume we send an empty string because otherwise
649                        // the Chat API will break -- there is a required message for every tool call by ID
650                        let content = match &tool_call.result {
651                            Some(result) => result.format(&tool_call.name),
652                            None => "".to_string(),
653                        };
654
655                        completion_messages.push(CompletionMessage::Tool {
656                            content,
657                            tool_call_id: tool_call.id.clone(),
658                        });
659                    }
660                }
661            }
662        }
663
664        completion_messages
665    }
666}
667
668impl Render for AssistantChat {
669    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
670        div()
671            .relative()
672            .flex_1()
673            .v_flex()
674            .key_context("AssistantChat")
675            .on_action(cx.listener(Self::submit))
676            .on_action(cx.listener(Self::cancel))
677            .text_color(Color::Default.color(cx))
678            .child(list(self.list_state.clone()).flex_1())
679            .child(Composer::new(
680                cx.view().downgrade(),
681                self.model.clone(),
682                self.composer_editor.clone(),
683                self.user_store.read(cx).current_user(),
684                self.can_submit(),
685                self.tool_registry.clone(),
686            ))
687    }
688}
689
/// Identifier of a chat message; ids are handed out in increasing order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
struct MessageId(usize);

impl MessageId {
    /// Advances this id by one and returns the value it held before the increment.
    fn post_inc(&mut self) -> Self {
        let next = MessageId(self.0 + 1);
        std::mem::replace(self, next)
    }
}
700
701enum ChatMessage {
702    User(UserMessage),
703    Assistant(AssistantMessage),
704}
705
706impl ChatMessage {
707    fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
708        match self {
709            ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
710            ChatMessage::Assistant(_) => None,
711        }
712    }
713}
714
/// A message written by the user.
struct UserMessage {
    // Id assigned from `AssistantChat::next_message_id` when the message was created.
    id: MessageId,
    // Editor holding the message text.
    body: View<Editor>,
}
719
/// A message produced by the assistant, filled in incrementally while a completion
/// streams.
struct AssistantMessage {
    // Id assigned from `AssistantChat::next_message_id` when the message was created.
    id: MessageId,
    // Rich-text rendering of the text received so far.
    body: RichText,
    // Tool calls requested by the model; replaced with the finished calls once the
    // tools have run.
    tool_calls: Vec<ToolFunctionCall>,
    // Error text shown under the message when the completion failed.
    error: Option<SharedString>,
}