assistant2.rs

  1mod assistant_settings;
  2mod completion_provider;
  3pub mod tools;
  4mod ui;
  5
  6use ::ui::{div, prelude::*, Color, ViewContext};
  7use anyhow::{Context, Result};
  8use assistant_tooling::{ToolFunctionCall, ToolRegistry};
  9use client::{proto, Client, UserStore};
 10use completion_provider::*;
 11use editor::Editor;
 12use feature_flags::FeatureFlagAppExt as _;
 13use futures::{future::join_all, StreamExt};
 14use gpui::{
 15    list, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, FocusableView,
 16    ListAlignment, ListState, Model, Render, Task, View, WeakView,
 17};
 18use language::{language_settings::SoftWrap, LanguageRegistry};
 19use open_ai::{FunctionContent, ToolCall, ToolCallContent};
 20use rich_text::RichText;
 21use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 22use serde::Deserialize;
 23use settings::Settings;
 24use std::sync::Arc;
 25use tools::ProjectIndexTool;
 26use ui::Composer;
 27use util::{paths::EMBEDDINGS_DIR, ResultExt};
 28use workspace::{
 29    dock::{DockPosition, Panel, PanelEvent},
 30    Workspace,
 31};
 32
 33pub use assistant_settings::AssistantSettings;
 34
 35use crate::ui::UserOrAssistant;
 36
 37const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
 38
 39#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 40pub struct Submit(SubmitMode);
 41
 42/// There are multiple different ways to submit a model request, represented by this enum.
 43#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 44pub enum SubmitMode {
 45    /// Only include the conversation.
 46    Simple,
 47    /// Send the current file as context.
 48    CurrentFile,
 49    /// Search the codebase and send relevant excerpts.
 50    Codebase,
 51}
 52
// Payload-free actions: `Cancel` is handled by `AssistantChat::cancel` (drops the pending
// completion), and `ToggleFocus` is registered on every workspace in `init`.
gpui::actions!(assistant2, [Cancel, ToggleFocus]);
// `Submit` is a hand-written struct carrying a `SubmitMode`; `impl_actions!` registers that
// existing type as an action.
gpui::impl_actions!(assistant2, [Submit]);
 55
 56pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 57    AssistantSettings::register(cx);
 58
 59    cx.spawn(|mut cx| {
 60        let client = client.clone();
 61        async move {
 62            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
 63            let semantic_index = SemanticIndex::new(
 64                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
 65                Arc::new(embedding_provider),
 66                &mut cx,
 67            )
 68            .await?;
 69            cx.update(|cx| cx.set_global(semantic_index))
 70        }
 71    })
 72    .detach();
 73
 74    cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
 75        client,
 76    )));
 77
 78    cx.observe_new_views(
 79        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 80            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 81                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 82            });
 83        },
 84    )
 85    .detach();
 86}
 87
 88pub fn enabled(cx: &AppContext) -> bool {
 89    cx.is_staff()
 90}
 91
 92pub struct AssistantPanel {
 93    chat: View<AssistantChat>,
 94    width: Option<Pixels>,
 95}
 96
 97impl AssistantPanel {
 98    pub fn load(
 99        workspace: WeakView<Workspace>,
100        cx: AsyncWindowContext,
101    ) -> Task<Result<View<Self>>> {
102        cx.spawn(|mut cx| async move {
103            let (app_state, project) = workspace.update(&mut cx, |workspace, _| {
104                (workspace.app_state().clone(), workspace.project().clone())
105            })?;
106
107            let user_store = app_state.user_store.clone();
108
109            cx.new_view(|cx| {
110                // todo!("this will panic if the semantic index failed to load or has not loaded yet")
111                let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
112                    semantic_index.project_index(project.clone(), cx)
113                });
114
115                let mut tool_registry = ToolRegistry::new();
116                tool_registry
117                    .register(
118                        ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
119                        cx,
120                    )
121                    .context("failed to register ProjectIndexTool")
122                    .log_err();
123
124                let tool_registry = Arc::new(tool_registry);
125
126                Self::new(app_state.languages.clone(), tool_registry, user_store, cx)
127            })
128        })
129    }
130
131    pub fn new(
132        language_registry: Arc<LanguageRegistry>,
133        tool_registry: Arc<ToolRegistry>,
134        user_store: Model<UserStore>,
135        cx: &mut ViewContext<Self>,
136    ) -> Self {
137        let chat = cx.new_view(|cx| {
138            AssistantChat::new(
139                language_registry.clone(),
140                tool_registry.clone(),
141                user_store,
142                cx,
143            )
144        });
145
146        Self { width: None, chat }
147    }
148}
149
150impl Render for AssistantPanel {
151    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
152        div()
153            .size_full()
154            .v_flex()
155            .p_2()
156            .bg(cx.theme().colors().background)
157            .child(self.chat.clone())
158    }
159}
160
161impl Panel for AssistantPanel {
162    fn persistent_name() -> &'static str {
163        "AssistantPanelv2"
164    }
165
166    fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition {
167        // todo!("Add a setting / use assistant settings")
168        DockPosition::Right
169    }
170
171    fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool {
172        matches!(position, DockPosition::Right)
173    }
174
175    fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext<Self>) {
176        // Do nothing until we have a setting for this
177    }
178
179    fn size(&self, _cx: &WindowContext) -> Pixels {
180        self.width.unwrap_or(px(400.))
181    }
182
183    fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
184        self.width = size;
185        cx.notify();
186    }
187
188    fn icon(&self, _cx: &WindowContext) -> Option<::ui::IconName> {
189        Some(IconName::Ai)
190    }
191
192    fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> {
193        Some("Assistant Panel ✨")
194    }
195
196    fn toggle_action(&self) -> Box<dyn gpui::Action> {
197        Box::new(ToggleFocus)
198    }
199}
200
201impl EventEmitter<PanelEvent> for AssistantPanel {}
202
203impl FocusableView for AssistantPanel {
204    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
205        self.chat.read(cx).composer_editor.read(cx).focus_handle(cx)
206    }
207}
208
209struct AssistantChat {
210    model: String,
211    messages: Vec<ChatMessage>,
212    list_state: ListState,
213    language_registry: Arc<LanguageRegistry>,
214    composer_editor: View<Editor>,
215    user_store: Model<UserStore>,
216    next_message_id: MessageId,
217    pending_completion: Option<Task<()>>,
218    tool_registry: Arc<ToolRegistry>,
219}
220
221impl AssistantChat {
222    fn new(
223        language_registry: Arc<LanguageRegistry>,
224        tool_registry: Arc<ToolRegistry>,
225        user_store: Model<UserStore>,
226        cx: &mut ViewContext<Self>,
227    ) -> Self {
228        let model = CompletionProvider::get(cx).default_model();
229        let view = cx.view().downgrade();
230        let list_state = ListState::new(
231            0,
232            ListAlignment::Bottom,
233            px(1024.),
234            move |ix, cx: &mut WindowContext| {
235                view.update(cx, |this, cx| this.render_message(ix, cx))
236                    .unwrap()
237            },
238        );
239
240        Self {
241            model,
242            messages: Vec::new(),
243            composer_editor: cx.new_view(|cx| {
244                let mut editor = Editor::auto_height(80, cx);
245                editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
246                editor.set_placeholder_text("Type a message to the assistant", cx);
247                editor
248            }),
249            list_state,
250            user_store,
251            language_registry,
252            next_message_id: MessageId(0),
253            pending_completion: None,
254            tool_registry,
255        }
256    }
257
258    fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
259        self.messages.iter().find_map(|message| match message {
260            ChatMessage::User(message) => message
261                .body
262                .focus_handle(cx)
263                .contains_focused(cx)
264                .then_some(message.id),
265            ChatMessage::Assistant(_) => None,
266        })
267    }
268
269    fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
270        if self.pending_completion.take().is_none() {
271            cx.propagate();
272            return;
273        }
274
275        if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
276            if message.body.text.is_empty() {
277                self.pop_message(cx);
278            }
279        }
280    }
281
282    fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
283        // Don't allow multiple concurrent completions.
284        if self.pending_completion.is_some() {
285            cx.propagate();
286            return;
287        }
288
289        if let Some(focused_message_id) = self.focused_message_id(cx) {
290            self.truncate_messages(focused_message_id, cx);
291        } else if self.composer_editor.focus_handle(cx).is_focused(cx) {
292            let message = self.composer_editor.update(cx, |composer_editor, cx| {
293                let text = composer_editor.text(cx);
294                let id = self.next_message_id.post_inc();
295                let body = cx.new_view(|cx| {
296                    let mut editor = Editor::auto_height(80, cx);
297                    editor.set_text(text, cx);
298                    editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
299                    editor
300                });
301                composer_editor.clear(cx);
302                ChatMessage::User(UserMessage { id, body })
303            });
304            self.push_message(message, cx);
305        } else {
306            log::error!("unexpected state: no user message editor is focused.");
307            return;
308        }
309
310        let mode = *mode;
311        self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
312            Self::request_completion(
313                this.clone(),
314                mode,
315                MAX_COMPLETION_CALLS_PER_SUBMISSION,
316                &mut cx,
317            )
318            .await
319            .log_err();
320
321            this.update(&mut cx, |this, cx| {
322                let composer_focus_handle = this.composer_editor.focus_handle(cx);
323                cx.focus(&composer_focus_handle);
324                this.pending_completion = None;
325            })
326            .context("Failed to push new user message")
327            .log_err();
328        }));
329    }
330
331    fn can_submit(&self) -> bool {
332        self.pending_completion.is_none()
333    }
334
335    async fn request_completion(
336        this: WeakView<Self>,
337        mode: SubmitMode,
338        limit: usize,
339        cx: &mut AsyncWindowContext,
340    ) -> Result<()> {
341        let mut call_count = 0;
342        loop {
343            let complete = async {
344                let completion = this.update(cx, |this, cx| {
345                    this.push_new_assistant_message(cx);
346
347                    let definitions = if call_count < limit
348                        && matches!(mode, SubmitMode::Codebase | SubmitMode::Simple)
349                    {
350                        this.tool_registry.definitions()
351                    } else {
352                        &[]
353                    };
354                    call_count += 1;
355
356                    let messages = this.completion_messages(cx);
357
358                    CompletionProvider::get(cx).complete(
359                        this.model.clone(),
360                        messages,
361                        Vec::new(),
362                        1.0,
363                        definitions,
364                    )
365                });
366
367                let mut stream = completion?.await?;
368                let mut body = String::new();
369                while let Some(delta) = stream.next().await {
370                    let delta = delta?;
371                    this.update(cx, |this, cx| {
372                        if let Some(ChatMessage::Assistant(AssistantMessage {
373                            body: message_body,
374                            tool_calls: message_tool_calls,
375                            ..
376                        })) = this.messages.last_mut()
377                        {
378                            if let Some(content) = &delta.content {
379                                body.push_str(content);
380                            }
381
382                            for tool_call in delta.tool_calls {
383                                let index = tool_call.index as usize;
384                                if index >= message_tool_calls.len() {
385                                    message_tool_calls.resize_with(index + 1, Default::default);
386                                }
387                                let call = &mut message_tool_calls[index];
388
389                                if let Some(id) = &tool_call.id {
390                                    call.id.push_str(id);
391                                }
392
393                                match tool_call.variant {
394                                    Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
395                                        if let Some(name) = &tool_call.name {
396                                            call.name.push_str(name);
397                                        }
398                                        if let Some(arguments) = &tool_call.arguments {
399                                            call.arguments.push_str(arguments);
400                                        }
401                                    }
402                                    None => {}
403                                }
404                            }
405
406                            *message_body =
407                                RichText::new(body.clone(), &[], &this.language_registry);
408                            cx.notify();
409                        } else {
410                            unreachable!()
411                        }
412                    })?;
413                }
414
415                anyhow::Ok(())
416            }
417            .await;
418
419            let mut tool_tasks = Vec::new();
420            this.update(cx, |this, cx| {
421                if let Some(ChatMessage::Assistant(AssistantMessage {
422                    error: message_error,
423                    tool_calls,
424                    ..
425                })) = this.messages.last_mut()
426                {
427                    if let Err(error) = complete {
428                        message_error.replace(SharedString::from(error.to_string()));
429                        cx.notify();
430                    } else {
431                        for tool_call in tool_calls.iter() {
432                            tool_tasks.push(this.tool_registry.call(tool_call, cx));
433                        }
434                    }
435                }
436            })?;
437
438            if tool_tasks.is_empty() {
439                return Ok(());
440            }
441
442            let tools = join_all(tool_tasks.into_iter()).await;
443            // If the WindowContext went away for any tool's view we don't include it
444            // especially since the below call would fail for the same reason.
445            let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
446
447            this.update(cx, |this, cx| {
448                if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
449                    this.messages.last_mut()
450                {
451                    *tool_calls = tools;
452                    cx.notify();
453                }
454            })?;
455        }
456    }
457
458    fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
459        let message = ChatMessage::Assistant(AssistantMessage {
460            id: self.next_message_id.post_inc(),
461            body: RichText::default(),
462            tool_calls: Vec::new(),
463            error: None,
464        });
465        self.push_message(message, cx);
466    }
467
468    fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext<Self>) {
469        let old_len = self.messages.len();
470        let focus_handle = Some(message.focus_handle(cx));
471        self.messages.push(message);
472        self.list_state
473            .splice_focusable(old_len..old_len, focus_handle);
474        cx.notify();
475    }
476
477    fn pop_message(&mut self, cx: &mut ViewContext<Self>) {
478        if self.messages.is_empty() {
479            return;
480        }
481
482        self.messages.pop();
483        self.list_state
484            .splice(self.messages.len()..self.messages.len() + 1, 0);
485        cx.notify();
486    }
487
488    fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext<Self>) {
489        if let Some(index) = self.messages.iter().position(|message| match message {
490            ChatMessage::User(message) => message.id == last_message_id,
491            ChatMessage::Assistant(message) => message.id == last_message_id,
492        }) {
493            self.list_state.splice(index + 1..self.messages.len(), 0);
494            self.messages.truncate(index + 1);
495            cx.notify();
496        }
497    }
498
499    fn render_error(
500        &self,
501        error: Option<SharedString>,
502        _ix: usize,
503        cx: &mut ViewContext<Self>,
504    ) -> AnyElement {
505        let theme = cx.theme();
506
507        if let Some(error) = error {
508            div()
509                .py_1()
510                .px_2()
511                .neg_mx_1()
512                .rounded_md()
513                .border()
514                .border_color(theme.status().error_border)
515                // .bg(theme.status().error_background)
516                .text_color(theme.status().error)
517                .child(error.clone())
518                .into_any_element()
519        } else {
520            div().into_any_element()
521        }
522    }
523
524    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
525        let is_last = ix == self.messages.len() - 1;
526
527        match &self.messages[ix] {
528            ChatMessage::User(UserMessage { id, body }) => div()
529                .when(!is_last, |element| element.mb_2())
530                .child(crate::ui::ChatMessage::new(
531                    *id,
532                    UserOrAssistant::User(self.user_store.read(cx).current_user()),
533                    body.clone().into_any_element(),
534                    false,
535                    Box::new(|_, _| {}),
536                ))
537                .into_any(),
538            ChatMessage::Assistant(AssistantMessage {
539                id,
540                body,
541                error,
542                tool_calls,
543                ..
544            }) => {
545                let assistant_body = if body.text.is_empty() && !tool_calls.is_empty() {
546                    div()
547                } else {
548                    div().p_2().child(body.element(ElementId::from(id.0), cx))
549                };
550
551                div()
552                    .when(!is_last, |element| element.mb_2())
553                    .child(crate::ui::ChatMessage::new(
554                        *id,
555                        UserOrAssistant::Assistant,
556                        assistant_body.into_any_element(),
557                        false,
558                        Box::new(|_, _| {}),
559                    ))
560                    // TODO: Should the errors and tool calls get passed into `ChatMessage`?
561                    .child(self.render_error(error.clone(), ix, cx))
562                    .children(tool_calls.iter().map(|tool_call| {
563                        let result = &tool_call.result;
564                        let name = tool_call.name.clone();
565                        match result {
566                            Some(result) => {
567                                div().p_2().child(result.into_any_element(&name)).into_any()
568                            }
569                            None => div()
570                                .p_2()
571                                .child(Label::new(name).color(Color::Modified))
572                                .child("Running...")
573                                .into_any(),
574                        }
575                    }))
576                    .into_any()
577            }
578        }
579    }
580
581    fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
582        let mut completion_messages = Vec::new();
583
584        for message in &self.messages {
585            match message {
586                ChatMessage::User(UserMessage { body, .. }) => {
587                    // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
588                    // contexts.iter().for_each(|context| {
589                    //     completion_messages.extend(context.completion_messages(cx))
590                    // });
591
592                    // Show user's message last so that the assistant is grounded in the user's request
593                    completion_messages.push(CompletionMessage::User {
594                        content: body.read(cx).text(cx),
595                    });
596                }
597                ChatMessage::Assistant(AssistantMessage {
598                    body, tool_calls, ..
599                }) => {
600                    // In no case do we want to send an empty message. This shouldn't happen, but we might as well
601                    // not break the Chat API if it does.
602                    if body.text.is_empty() && tool_calls.is_empty() {
603                        continue;
604                    }
605
606                    let tool_calls_from_assistant = tool_calls
607                        .iter()
608                        .map(|tool_call| ToolCall {
609                            content: ToolCallContent::Function {
610                                function: FunctionContent {
611                                    name: tool_call.name.clone(),
612                                    arguments: tool_call.arguments.clone(),
613                                },
614                            },
615                            id: tool_call.id.clone(),
616                        })
617                        .collect();
618
619                    completion_messages.push(CompletionMessage::Assistant {
620                        content: Some(body.text.to_string()),
621                        tool_calls: tool_calls_from_assistant,
622                    });
623
624                    for tool_call in tool_calls {
625                        // todo!(): we should not be sending when the tool is still running / has no result
626                        // For now I'm going to have to assume we send an empty string because otherwise
627                        // the Chat API will break -- there is a required message for every tool call by ID
628                        let content = match &tool_call.result {
629                            Some(result) => result.format(&tool_call.name),
630                            None => "".to_string(),
631                        };
632
633                        completion_messages.push(CompletionMessage::Tool {
634                            content,
635                            tool_call_id: tool_call.id.clone(),
636                        });
637                    }
638                }
639            }
640        }
641
642        completion_messages
643    }
644}
645
646impl Render for AssistantChat {
647    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
648        div()
649            .relative()
650            .flex_1()
651            .v_flex()
652            .key_context("AssistantChat")
653            .on_action(cx.listener(Self::submit))
654            .on_action(cx.listener(Self::cancel))
655            .text_color(Color::Default.color(cx))
656            .child(list(self.list_state.clone()).flex_1())
657            .child(Composer::new(
658                cx.view().downgrade(),
659                self.model.clone(),
660                self.composer_editor.clone(),
661                self.user_store.read(cx).current_user(),
662                self.can_submit(),
663                self.tool_registry.clone(),
664            ))
665    }
666}
667
/// Identifier for a chat message, handed out sequentially by `AssistantChat`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct MessageId(usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one, like a postfix increment.
    fn post_inc(&mut self) -> Self {
        let successor = MessageId(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
678
679enum ChatMessage {
680    User(UserMessage),
681    Assistant(AssistantMessage),
682}
683
684impl ChatMessage {
685    fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
686        match self {
687            ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
688            ChatMessage::Assistant(_) => None,
689        }
690    }
691}
692
693struct UserMessage {
694    id: MessageId,
695    body: View<Editor>,
696}
697
698struct AssistantMessage {
699    id: MessageId,
700    body: RichText,
701    tool_calls: Vec<ToolFunctionCall>,
702    error: Option<SharedString>,
703}