assistant2.rs

  1mod assistant_settings;
  2mod completion_provider;
  3pub mod tools;
  4
  5use anyhow::{Context, Result};
  6use assistant_tooling::{ToolFunctionCall, ToolRegistry};
  7use client::{proto, Client};
  8use completion_provider::*;
  9use editor::Editor;
 10use feature_flags::FeatureFlagAppExt as _;
 11use futures::{future::join_all, StreamExt};
 12use gpui::{
 13    list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
 14    FocusableView, Global, ListAlignment, ListState, Render, Task, View, WeakView,
 15};
 16use language::{language_settings::SoftWrap, LanguageRegistry};
 17use open_ai::{FunctionContent, ToolCall, ToolCallContent};
 18use rich_text::RichText;
 19use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 20use serde::Deserialize;
 21use settings::Settings;
 22use std::sync::Arc;
 23use theme::ThemeSettings;
 24use tools::ProjectIndexTool;
 25use ui::{popover_menu, prelude::*, ButtonLike, Color, ContextMenu, Tooltip};
 26use util::{paths::EMBEDDINGS_DIR, ResultExt};
 27use workspace::{
 28    dock::{DockPosition, Panel, PanelEvent},
 29    Workspace,
 30};
 31
 32pub use assistant_settings::AssistantSettings;
 33
 34const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
 35
 36#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 37pub struct Submit(SubmitMode);
 38
 39/// There are multiple different ways to submit a model request, represented by this enum.
 40#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 41pub enum SubmitMode {
 42    /// Only include the conversation.
 43    Simple,
 44    /// Send the current file as context.
 45    CurrentFile,
 46    /// Search the codebase and send relevant excerpts.
 47    Codebase,
 48}
 49
 50gpui::actions!(assistant2, [Cancel, ToggleFocus]);
 51gpui::impl_actions!(assistant2, [Submit]);
 52
 53pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 54    AssistantSettings::register(cx);
 55
 56    cx.spawn(|mut cx| {
 57        let client = client.clone();
 58        async move {
 59            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
 60            let semantic_index = SemanticIndex::new(
 61                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
 62                Arc::new(embedding_provider),
 63                &mut cx,
 64            )
 65            .await?;
 66            cx.update(|cx| cx.set_global(semantic_index))
 67        }
 68    })
 69    .detach();
 70
 71    cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
 72        client,
 73    )));
 74
 75    cx.observe_new_views(
 76        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 77            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 78                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 79            });
 80        },
 81    )
 82    .detach();
 83}
 84
 85pub fn enabled(cx: &AppContext) -> bool {
 86    cx.is_staff()
 87}
 88
 89pub struct AssistantPanel {
 90    chat: View<AssistantChat>,
 91    width: Option<Pixels>,
 92}
 93
 94impl AssistantPanel {
 95    pub fn load(
 96        workspace: WeakView<Workspace>,
 97        cx: AsyncWindowContext,
 98    ) -> Task<Result<View<Self>>> {
 99        cx.spawn(|mut cx| async move {
100            let (app_state, project) = workspace.update(&mut cx, |workspace, _| {
101                (workspace.app_state().clone(), workspace.project().clone())
102            })?;
103
104            cx.new_view(|cx| {
105                // todo!("this will panic if the semantic index failed to load or has not loaded yet")
106                let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
107                    semantic_index.project_index(project.clone(), cx)
108                });
109
110                let mut tool_registry = ToolRegistry::new();
111                tool_registry
112                    .register(
113                        ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
114                        cx,
115                    )
116                    .context("failed to register ProjectIndexTool")
117                    .log_err();
118
119                let tool_registry = Arc::new(tool_registry);
120
121                Self::new(app_state.languages.clone(), tool_registry, cx)
122            })
123        })
124    }
125
126    pub fn new(
127        language_registry: Arc<LanguageRegistry>,
128        tool_registry: Arc<ToolRegistry>,
129        cx: &mut ViewContext<Self>,
130    ) -> Self {
131        let chat = cx.new_view(|cx| {
132            AssistantChat::new(language_registry.clone(), tool_registry.clone(), cx)
133        });
134
135        Self { width: None, chat }
136    }
137}
138
139impl Render for AssistantPanel {
140    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
141        div()
142            .size_full()
143            .v_flex()
144            .p_2()
145            .bg(cx.theme().colors().background)
146            .child(self.chat.clone())
147    }
148}
149
150impl Panel for AssistantPanel {
151    fn persistent_name() -> &'static str {
152        "AssistantPanelv2"
153    }
154
155    fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition {
156        // todo!("Add a setting / use assistant settings")
157        DockPosition::Right
158    }
159
160    fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool {
161        matches!(position, DockPosition::Right)
162    }
163
164    fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext<Self>) {
165        // Do nothing until we have a setting for this
166    }
167
168    fn size(&self, _cx: &WindowContext) -> Pixels {
169        self.width.unwrap_or(px(400.))
170    }
171
172    fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
173        self.width = size;
174        cx.notify();
175    }
176
177    fn icon(&self, _cx: &WindowContext) -> Option<ui::IconName> {
178        Some(IconName::Ai)
179    }
180
181    fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> {
182        Some("Assistant Panel ✨")
183    }
184
185    fn toggle_action(&self) -> Box<dyn gpui::Action> {
186        Box::new(ToggleFocus)
187    }
188}
189
190impl EventEmitter<PanelEvent> for AssistantPanel {}
191
192impl FocusableView for AssistantPanel {
193    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
194        self.chat
195            .read(cx)
196            .messages
197            .iter()
198            .rev()
199            .find_map(|msg| msg.focus_handle(cx))
200            .expect("no user message in chat")
201    }
202}
203
204struct AssistantChat {
205    model: String,
206    messages: Vec<ChatMessage>,
207    list_state: ListState,
208    language_registry: Arc<LanguageRegistry>,
209    next_message_id: MessageId,
210    pending_completion: Option<Task<()>>,
211    tool_registry: Arc<ToolRegistry>,
212}
213
214impl AssistantChat {
215    fn new(
216        language_registry: Arc<LanguageRegistry>,
217        tool_registry: Arc<ToolRegistry>,
218        cx: &mut ViewContext<Self>,
219    ) -> Self {
220        let model = CompletionProvider::get(cx).default_model();
221        let view = cx.view().downgrade();
222        let list_state = ListState::new(
223            0,
224            ListAlignment::Bottom,
225            px(1024.),
226            move |ix, cx: &mut WindowContext| {
227                view.update(cx, |this, cx| this.render_message(ix, cx))
228                    .unwrap()
229            },
230        );
231
232        let mut this = Self {
233            model,
234            messages: Vec::new(),
235            list_state,
236            language_registry,
237            next_message_id: MessageId(0),
238            pending_completion: None,
239            tool_registry,
240        };
241        this.push_new_user_message(true, cx);
242        this
243    }
244
245    fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
246        self.messages.iter().find_map(|message| match message {
247            ChatMessage::User(message) => message
248                .body
249                .focus_handle(cx)
250                .contains_focused(cx)
251                .then_some(message.id),
252            ChatMessage::Assistant(_) => None,
253        })
254    }
255
256    fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
257        if self.pending_completion.take().is_none() {
258            cx.propagate();
259            return;
260        }
261
262        if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
263            if message.body.text.is_empty() {
264                self.pop_message(cx);
265            } else {
266                self.push_new_user_message(false, cx);
267            }
268        }
269    }
270
271    fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
272        let Some(focused_message_id) = self.focused_message_id(cx) else {
273            log::error!("unexpected state: no user message editor is focused.");
274            return;
275        };
276
277        self.truncate_messages(focused_message_id, cx);
278
279        let mode = *mode;
280        self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
281            Self::request_completion(
282                this.clone(),
283                mode,
284                MAX_COMPLETION_CALLS_PER_SUBMISSION,
285                &mut cx,
286            )
287            .await
288            .log_err();
289
290            this.update(&mut cx, |this, cx| {
291                let focus = this
292                    .user_message(focused_message_id)
293                    .body
294                    .focus_handle(cx)
295                    .contains_focused(cx);
296                this.push_new_user_message(focus, cx);
297                this.pending_completion = None;
298            })
299            .context("Failed to push new user message")
300            .log_err();
301        }));
302    }
303
304    async fn request_completion(
305        this: WeakView<Self>,
306        mode: SubmitMode,
307        limit: usize,
308        cx: &mut AsyncWindowContext,
309    ) -> Result<()> {
310        let mut call_count = 0;
311        loop {
312            let complete = async {
313                let completion = this.update(cx, |this, cx| {
314                    this.push_new_assistant_message(cx);
315
316                    let definitions = if call_count < limit
317                        && matches!(mode, SubmitMode::Codebase | SubmitMode::Simple)
318                    {
319                        this.tool_registry.definitions()
320                    } else {
321                        &[]
322                    };
323                    call_count += 1;
324
325                    let messages = this.completion_messages(cx);
326
327                    CompletionProvider::get(cx).complete(
328                        this.model.clone(),
329                        messages,
330                        Vec::new(),
331                        1.0,
332                        definitions,
333                    )
334                });
335
336                let mut stream = completion?.await?;
337                let mut body = String::new();
338                while let Some(delta) = stream.next().await {
339                    let delta = delta?;
340                    this.update(cx, |this, cx| {
341                        if let Some(ChatMessage::Assistant(AssistantMessage {
342                            body: message_body,
343                            tool_calls: message_tool_calls,
344                            ..
345                        })) = this.messages.last_mut()
346                        {
347                            if let Some(content) = &delta.content {
348                                body.push_str(content);
349                            }
350
351                            for tool_call in delta.tool_calls {
352                                let index = tool_call.index as usize;
353                                if index >= message_tool_calls.len() {
354                                    message_tool_calls.resize_with(index + 1, Default::default);
355                                }
356                                let call = &mut message_tool_calls[index];
357
358                                if let Some(id) = &tool_call.id {
359                                    call.id.push_str(id);
360                                }
361
362                                match tool_call.variant {
363                                    Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
364                                        if let Some(name) = &tool_call.name {
365                                            call.name.push_str(name);
366                                        }
367                                        if let Some(arguments) = &tool_call.arguments {
368                                            call.arguments.push_str(arguments);
369                                        }
370                                    }
371                                    None => {}
372                                }
373                            }
374
375                            *message_body =
376                                RichText::new(body.clone(), &[], &this.language_registry);
377                            cx.notify();
378                        } else {
379                            unreachable!()
380                        }
381                    })?;
382                }
383
384                anyhow::Ok(())
385            }
386            .await;
387
388            let mut tool_tasks = Vec::new();
389            this.update(cx, |this, cx| {
390                if let Some(ChatMessage::Assistant(AssistantMessage {
391                    error: message_error,
392                    tool_calls,
393                    ..
394                })) = this.messages.last_mut()
395                {
396                    if let Err(error) = complete {
397                        message_error.replace(SharedString::from(error.to_string()));
398                        cx.notify();
399                    } else {
400                        for tool_call in tool_calls.iter() {
401                            tool_tasks.push(this.tool_registry.call(tool_call, cx));
402                        }
403                    }
404                }
405            })?;
406
407            if tool_tasks.is_empty() {
408                return Ok(());
409            }
410
411            let tools = join_all(tool_tasks.into_iter()).await;
412            // If the WindowContext went away for any tool's view we don't include it
413            // especially since the below call would fail for the same reason.
414            let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
415
416            this.update(cx, |this, cx| {
417                if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
418                    this.messages.last_mut()
419                {
420                    *tool_calls = tools;
421                    cx.notify();
422                }
423            })?;
424        }
425    }
426
427    fn user_message(&mut self, message_id: MessageId) -> &mut UserMessage {
428        self.messages
429            .iter_mut()
430            .find_map(|message| match message {
431                ChatMessage::User(user_message) if user_message.id == message_id => {
432                    Some(user_message)
433                }
434                _ => None,
435            })
436            .expect("User message not found")
437    }
438
439    fn push_new_user_message(&mut self, focus: bool, cx: &mut ViewContext<Self>) {
440        let id = self.next_message_id.post_inc();
441        let body = cx.new_view(|cx| {
442            let mut editor = Editor::auto_height(80, cx);
443            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
444            if focus {
445                cx.focus_self();
446            }
447            editor
448        });
449        let message = ChatMessage::User(UserMessage { id, body });
450        self.push_message(message, cx);
451    }
452
453    fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
454        let message = ChatMessage::Assistant(AssistantMessage {
455            id: self.next_message_id.post_inc(),
456            body: RichText::default(),
457            tool_calls: Vec::new(),
458            error: None,
459        });
460        self.push_message(message, cx);
461    }
462
463    fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext<Self>) {
464        let old_len = self.messages.len();
465        let focus_handle = Some(message.focus_handle(cx));
466        self.messages.push(message);
467        self.list_state
468            .splice_focusable(old_len..old_len, focus_handle);
469        cx.notify();
470    }
471
472    fn pop_message(&mut self, cx: &mut ViewContext<Self>) {
473        if self.messages.is_empty() {
474            return;
475        }
476
477        self.messages.pop();
478        self.list_state
479            .splice(self.messages.len()..self.messages.len() + 1, 0);
480        cx.notify();
481    }
482
483    fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext<Self>) {
484        if let Some(index) = self.messages.iter().position(|message| match message {
485            ChatMessage::User(message) => message.id == last_message_id,
486            ChatMessage::Assistant(message) => message.id == last_message_id,
487        }) {
488            self.list_state.splice(index + 1..self.messages.len(), 0);
489            self.messages.truncate(index + 1);
490            cx.notify();
491        }
492    }
493
494    fn render_error(
495        &self,
496        error: Option<SharedString>,
497        _ix: usize,
498        cx: &mut ViewContext<Self>,
499    ) -> AnyElement {
500        let theme = cx.theme();
501
502        if let Some(error) = error {
503            div()
504                .py_1()
505                .px_2()
506                .neg_mx_1()
507                .rounded_md()
508                .border()
509                .border_color(theme.status().error_border)
510                // .bg(theme.status().error_background)
511                .text_color(theme.status().error)
512                .child(error.clone())
513                .into_any_element()
514        } else {
515            div().into_any_element()
516        }
517    }
518
519    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
520        let is_last = ix == self.messages.len() - 1;
521
522        match &self.messages[ix] {
523            ChatMessage::User(UserMessage { body, .. }) => div()
524                .when(!is_last, |element| element.mb_2())
525                .child(div().p_2().child(Label::new("You").color(Color::Default)))
526                .child(
527                    div()
528                        .on_action(cx.listener(Self::submit))
529                        .p_2()
530                        .text_color(cx.theme().colors().editor_foreground)
531                        .font(ThemeSettings::get_global(cx).buffer_font.clone())
532                        .bg(cx.theme().colors().editor_background)
533                        .child(body.clone()),
534                )
535                .into_any(),
536            ChatMessage::Assistant(AssistantMessage {
537                id,
538                body,
539                error,
540                tool_calls,
541                ..
542            }) => {
543                let assistant_body = if body.text.is_empty() && !tool_calls.is_empty() {
544                    div()
545                } else {
546                    div().p_2().child(body.element(ElementId::from(id.0), cx))
547                };
548
549                div()
550                    .when(!is_last, |element| element.mb_2())
551                    .child(
552                        div()
553                            .p_2()
554                            .child(Label::new("Assistant").color(Color::Modified)),
555                    )
556                    .child(assistant_body)
557                    .child(self.render_error(error.clone(), ix, cx))
558                    .children(tool_calls.iter().map(|tool_call| {
559                        let result = &tool_call.result;
560                        let name = tool_call.name.clone();
561                        match result {
562                            Some(result) => {
563                                div().p_2().child(result.into_any_element(&name)).into_any()
564                            }
565                            None => div()
566                                .p_2()
567                                .child(Label::new(name).color(Color::Modified))
568                                .child("Running...")
569                                .into_any(),
570                        }
571                    }))
572                    .into_any()
573            }
574        }
575    }
576
577    fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
578        let mut completion_messages = Vec::new();
579
580        for message in &self.messages {
581            match message {
582                ChatMessage::User(UserMessage { body, .. }) => {
583                    // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
584                    // contexts.iter().for_each(|context| {
585                    //     completion_messages.extend(context.completion_messages(cx))
586                    // });
587
588                    // Show user's message last so that the assistant is grounded in the user's request
589                    completion_messages.push(CompletionMessage::User {
590                        content: body.read(cx).text(cx),
591                    });
592                }
593                ChatMessage::Assistant(AssistantMessage {
594                    body, tool_calls, ..
595                }) => {
596                    // In no case do we want to send an empty message. This shouldn't happen, but we might as well
597                    // not break the Chat API if it does.
598                    if body.text.is_empty() && tool_calls.is_empty() {
599                        continue;
600                    }
601
602                    let tool_calls_from_assistant = tool_calls
603                        .iter()
604                        .map(|tool_call| ToolCall {
605                            content: ToolCallContent::Function {
606                                function: FunctionContent {
607                                    name: tool_call.name.clone(),
608                                    arguments: tool_call.arguments.clone(),
609                                },
610                            },
611                            id: tool_call.id.clone(),
612                        })
613                        .collect();
614
615                    completion_messages.push(CompletionMessage::Assistant {
616                        content: Some(body.text.to_string()),
617                        tool_calls: tool_calls_from_assistant,
618                    });
619
620                    for tool_call in tool_calls {
621                        // todo!(): we should not be sending when the tool is still running / has no result
622                        // For now I'm going to have to assume we send an empty string because otherwise
623                        // the Chat API will break -- there is a required message for every tool call by ID
624                        let content = match &tool_call.result {
625                            Some(result) => result.format(&tool_call.name),
626                            None => "".to_string(),
627                        };
628
629                        completion_messages.push(CompletionMessage::Tool {
630                            content,
631                            tool_call_id: tool_call.id.clone(),
632                        });
633                    }
634                }
635            }
636        }
637
638        completion_messages
639    }
640
641    fn render_model_dropdown(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
642        let this = cx.view().downgrade();
643        div().h_flex().justify_end().child(
644            div().w_32().child(
645                popover_menu("user-menu")
646                    .menu(move |cx| {
647                        ContextMenu::build(cx, |mut menu, cx| {
648                            for model in CompletionProvider::get(cx).available_models() {
649                                menu = menu.custom_entry(
650                                    {
651                                        let model = model.clone();
652                                        move |_| Label::new(model.clone()).into_any_element()
653                                    },
654                                    {
655                                        let this = this.clone();
656                                        move |cx| {
657                                            _ = this.update(cx, |this, cx| {
658                                                this.model = model.clone();
659                                                cx.notify();
660                                            });
661                                        }
662                                    },
663                                );
664                            }
665                            menu
666                        })
667                        .into()
668                    })
669                    .trigger(
670                        ButtonLike::new("active-model")
671                            .child(
672                                h_flex()
673                                    .w_full()
674                                    .gap_0p5()
675                                    .child(
676                                        div()
677                                            .overflow_x_hidden()
678                                            .flex_grow()
679                                            .whitespace_nowrap()
680                                            .child(Label::new(self.model.clone())),
681                                    )
682                                    .child(div().child(
683                                        Icon::new(IconName::ChevronDown).color(Color::Muted),
684                                    )),
685                            )
686                            .style(ButtonStyle::Subtle)
687                            .tooltip(move |cx| Tooltip::text("Change Model", cx)),
688                    )
689                    .anchor(gpui::AnchorCorner::TopRight),
690            ),
691        )
692    }
693}
694
695impl Render for AssistantChat {
696    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
697        div()
698            .relative()
699            .flex_1()
700            .v_flex()
701            .key_context("AssistantChat")
702            .on_action(cx.listener(Self::cancel))
703            .text_color(Color::Default.color(cx))
704            .child(self.render_model_dropdown(cx))
705            .child(list(self.list_state.clone()).flex_1())
706            .child(
707                h_flex()
708                    .mt_2()
709                    .gap_2()
710                    .children(self.tool_registry.status_views().iter().cloned()),
711            )
712    }
713}
714
/// Identifier assigned to each chat message, handed out in increasing order.
#[derive(Copy, Clone, Eq, PartialEq)]
struct MessageId(usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one (like `i++`).
    fn post_inc(&mut self) -> Self {
        let successor = self.0 + 1;
        Self(std::mem::replace(&mut self.0, successor))
    }
}
725
726enum ChatMessage {
727    User(UserMessage),
728    Assistant(AssistantMessage),
729}
730
731impl ChatMessage {
732    fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
733        match self {
734            ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
735            ChatMessage::Assistant(_) => None,
736        }
737    }
738}
739
740struct UserMessage {
741    id: MessageId,
742    body: View<Editor>,
743}
744
745struct AssistantMessage {
746    id: MessageId,
747    body: RichText,
748    tool_calls: Vec<ToolFunctionCall>,
749    error: Option<SharedString>,
750}