assistant2.rs

  1mod assistant_settings;
  2mod completion_provider;
  3pub mod tools;
  4mod ui;
  5
  6use ::ui::{div, prelude::*, Color, ViewContext};
  7use anyhow::{Context, Result};
  8use assistant_tooling::{ToolFunctionCall, ToolRegistry};
  9use client::{proto, Client, UserStore};
 10use completion_provider::*;
 11use editor::Editor;
 12use feature_flags::FeatureFlagAppExt as _;
 13use futures::{future::join_all, StreamExt};
 14use gpui::{
 15    list, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, FocusableView,
 16    ListAlignment, ListState, Model, Render, Task, View, WeakView,
 17};
 18use language::{language_settings::SoftWrap, LanguageRegistry};
 19use open_ai::{FunctionContent, ToolCall, ToolCallContent};
 20use rich_text::RichText;
 21use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 22use serde::Deserialize;
 23use settings::Settings;
 24use std::sync::Arc;
 25use theme::ThemeSettings;
 26use tools::ProjectIndexTool;
 27use ui::Composer;
 28use util::{paths::EMBEDDINGS_DIR, ResultExt};
 29use workspace::{
 30    dock::{DockPosition, Panel, PanelEvent},
 31    Workspace,
 32};
 33
 34pub use assistant_settings::AssistantSettings;
 35
 36const MAX_COMPLETION_CALLS_PER_SUBMISSION: usize = 5;
 37
 38#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 39pub struct Submit(SubmitMode);
 40
 41/// There are multiple different ways to submit a model request, represented by this enum.
 42#[derive(Eq, PartialEq, Copy, Clone, Deserialize)]
 43pub enum SubmitMode {
 44    /// Only include the conversation.
 45    Simple,
 46    /// Send the current file as context.
 47    CurrentFile,
 48    /// Search the codebase and send relevant excerpts.
 49    Codebase,
 50}
 51
 52gpui::actions!(assistant2, [Cancel, ToggleFocus]);
 53gpui::impl_actions!(assistant2, [Submit]);
 54
 55pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 56    AssistantSettings::register(cx);
 57
 58    cx.spawn(|mut cx| {
 59        let client = client.clone();
 60        async move {
 61            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
 62            let semantic_index = SemanticIndex::new(
 63                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
 64                Arc::new(embedding_provider),
 65                &mut cx,
 66            )
 67            .await?;
 68            cx.update(|cx| cx.set_global(semantic_index))
 69        }
 70    })
 71    .detach();
 72
 73    cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
 74        client,
 75    )));
 76
 77    cx.observe_new_views(
 78        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
 79            workspace.register_action(|workspace, _: &ToggleFocus, cx| {
 80                workspace.toggle_panel_focus::<AssistantPanel>(cx);
 81            });
 82        },
 83    )
 84    .detach();
 85}
 86
 87pub fn enabled(cx: &AppContext) -> bool {
 88    cx.is_staff()
 89}
 90
 91pub struct AssistantPanel {
 92    chat: View<AssistantChat>,
 93    width: Option<Pixels>,
 94}
 95
 96impl AssistantPanel {
 97    pub fn load(
 98        workspace: WeakView<Workspace>,
 99        cx: AsyncWindowContext,
100    ) -> Task<Result<View<Self>>> {
101        cx.spawn(|mut cx| async move {
102            let (app_state, project) = workspace.update(&mut cx, |workspace, _| {
103                (workspace.app_state().clone(), workspace.project().clone())
104            })?;
105
106            let user_store = app_state.user_store.clone();
107
108            cx.new_view(|cx| {
109                // todo!("this will panic if the semantic index failed to load or has not loaded yet")
110                let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
111                    semantic_index.project_index(project.clone(), cx)
112                });
113
114                let mut tool_registry = ToolRegistry::new();
115                tool_registry
116                    .register(
117                        ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
118                        cx,
119                    )
120                    .context("failed to register ProjectIndexTool")
121                    .log_err();
122
123                let tool_registry = Arc::new(tool_registry);
124
125                Self::new(app_state.languages.clone(), tool_registry, user_store, cx)
126            })
127        })
128    }
129
130    pub fn new(
131        language_registry: Arc<LanguageRegistry>,
132        tool_registry: Arc<ToolRegistry>,
133        user_store: Model<UserStore>,
134        cx: &mut ViewContext<Self>,
135    ) -> Self {
136        let chat = cx.new_view(|cx| {
137            AssistantChat::new(
138                language_registry.clone(),
139                tool_registry.clone(),
140                user_store,
141                cx,
142            )
143        });
144
145        Self { width: None, chat }
146    }
147}
148
149impl Render for AssistantPanel {
150    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
151        div()
152            .size_full()
153            .v_flex()
154            .p_2()
155            .bg(cx.theme().colors().background)
156            .child(self.chat.clone())
157    }
158}
159
160impl Panel for AssistantPanel {
161    fn persistent_name() -> &'static str {
162        "AssistantPanelv2"
163    }
164
165    fn position(&self, _cx: &WindowContext) -> workspace::dock::DockPosition {
166        // todo!("Add a setting / use assistant settings")
167        DockPosition::Right
168    }
169
170    fn position_is_valid(&self, position: workspace::dock::DockPosition) -> bool {
171        matches!(position, DockPosition::Right)
172    }
173
174    fn set_position(&mut self, _: workspace::dock::DockPosition, _: &mut ViewContext<Self>) {
175        // Do nothing until we have a setting for this
176    }
177
178    fn size(&self, _cx: &WindowContext) -> Pixels {
179        self.width.unwrap_or(px(400.))
180    }
181
182    fn set_size(&mut self, size: Option<Pixels>, cx: &mut ViewContext<Self>) {
183        self.width = size;
184        cx.notify();
185    }
186
187    fn icon(&self, _cx: &WindowContext) -> Option<::ui::IconName> {
188        Some(IconName::Ai)
189    }
190
191    fn icon_tooltip(&self, _: &WindowContext) -> Option<&'static str> {
192        Some("Assistant Panel ✨")
193    }
194
195    fn toggle_action(&self) -> Box<dyn gpui::Action> {
196        Box::new(ToggleFocus)
197    }
198}
199
200impl EventEmitter<PanelEvent> for AssistantPanel {}
201
202impl FocusableView for AssistantPanel {
203    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
204        self.chat.read(cx).composer_editor.read(cx).focus_handle(cx)
205    }
206}
207
208struct AssistantChat {
209    model: String,
210    messages: Vec<ChatMessage>,
211    list_state: ListState,
212    language_registry: Arc<LanguageRegistry>,
213    composer_editor: View<Editor>,
214    user_store: Model<UserStore>,
215    next_message_id: MessageId,
216    pending_completion: Option<Task<()>>,
217    tool_registry: Arc<ToolRegistry>,
218}
219
220impl AssistantChat {
221    fn new(
222        language_registry: Arc<LanguageRegistry>,
223        tool_registry: Arc<ToolRegistry>,
224        user_store: Model<UserStore>,
225        cx: &mut ViewContext<Self>,
226    ) -> Self {
227        let model = CompletionProvider::get(cx).default_model();
228        let view = cx.view().downgrade();
229        let list_state = ListState::new(
230            0,
231            ListAlignment::Bottom,
232            px(1024.),
233            move |ix, cx: &mut WindowContext| {
234                view.update(cx, |this, cx| this.render_message(ix, cx))
235                    .unwrap()
236            },
237        );
238
239        Self {
240            model,
241            messages: Vec::new(),
242            composer_editor: cx.new_view(|cx| {
243                let mut editor = Editor::auto_height(80, cx);
244                editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
245                editor.set_placeholder_text("Type a message to the assistant", cx);
246                editor
247            }),
248            list_state,
249            user_store,
250            language_registry,
251            next_message_id: MessageId(0),
252            pending_completion: None,
253            tool_registry,
254        }
255    }
256
257    fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
258        self.messages.iter().find_map(|message| match message {
259            ChatMessage::User(message) => message
260                .body
261                .focus_handle(cx)
262                .contains_focused(cx)
263                .then_some(message.id),
264            ChatMessage::Assistant(_) => None,
265        })
266    }
267
268    fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
269        if self.pending_completion.take().is_none() {
270            cx.propagate();
271            return;
272        }
273
274        if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
275            if message.body.text.is_empty() {
276                self.pop_message(cx);
277            }
278        }
279    }
280
281    fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
282        // Don't allow multiple concurrent completions.
283        if self.pending_completion.is_some() {
284            cx.propagate();
285            return;
286        }
287
288        if let Some(focused_message_id) = self.focused_message_id(cx) {
289            self.truncate_messages(focused_message_id, cx);
290        } else if self.composer_editor.focus_handle(cx).is_focused(cx) {
291            let message = self.composer_editor.update(cx, |composer_editor, cx| {
292                let text = composer_editor.text(cx);
293                let id = self.next_message_id.post_inc();
294                let body = cx.new_view(|cx| {
295                    let mut editor = Editor::auto_height(80, cx);
296                    editor.set_text(text, cx);
297                    editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
298                    editor
299                });
300                composer_editor.clear(cx);
301                ChatMessage::User(UserMessage { id, body })
302            });
303            self.push_message(message, cx);
304        } else {
305            log::error!("unexpected state: no user message editor is focused.");
306            return;
307        }
308
309        let mode = *mode;
310        self.pending_completion = Some(cx.spawn(move |this, mut cx| async move {
311            Self::request_completion(
312                this.clone(),
313                mode,
314                MAX_COMPLETION_CALLS_PER_SUBMISSION,
315                &mut cx,
316            )
317            .await
318            .log_err();
319
320            this.update(&mut cx, |this, cx| {
321                let composer_focus_handle = this.composer_editor.focus_handle(cx);
322                cx.focus(&composer_focus_handle);
323                this.pending_completion = None;
324            })
325            .context("Failed to push new user message")
326            .log_err();
327        }));
328    }
329
330    fn can_submit(&self) -> bool {
331        self.pending_completion.is_none()
332    }
333
334    async fn request_completion(
335        this: WeakView<Self>,
336        mode: SubmitMode,
337        limit: usize,
338        cx: &mut AsyncWindowContext,
339    ) -> Result<()> {
340        let mut call_count = 0;
341        loop {
342            let complete = async {
343                let completion = this.update(cx, |this, cx| {
344                    this.push_new_assistant_message(cx);
345
346                    let definitions = if call_count < limit
347                        && matches!(mode, SubmitMode::Codebase | SubmitMode::Simple)
348                    {
349                        this.tool_registry.definitions()
350                    } else {
351                        &[]
352                    };
353                    call_count += 1;
354
355                    let messages = this.completion_messages(cx);
356
357                    CompletionProvider::get(cx).complete(
358                        this.model.clone(),
359                        messages,
360                        Vec::new(),
361                        1.0,
362                        definitions,
363                    )
364                });
365
366                let mut stream = completion?.await?;
367                let mut body = String::new();
368                while let Some(delta) = stream.next().await {
369                    let delta = delta?;
370                    this.update(cx, |this, cx| {
371                        if let Some(ChatMessage::Assistant(AssistantMessage {
372                            body: message_body,
373                            tool_calls: message_tool_calls,
374                            ..
375                        })) = this.messages.last_mut()
376                        {
377                            if let Some(content) = &delta.content {
378                                body.push_str(content);
379                            }
380
381                            for tool_call in delta.tool_calls {
382                                let index = tool_call.index as usize;
383                                if index >= message_tool_calls.len() {
384                                    message_tool_calls.resize_with(index + 1, Default::default);
385                                }
386                                let call = &mut message_tool_calls[index];
387
388                                if let Some(id) = &tool_call.id {
389                                    call.id.push_str(id);
390                                }
391
392                                match tool_call.variant {
393                                    Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
394                                        if let Some(name) = &tool_call.name {
395                                            call.name.push_str(name);
396                                        }
397                                        if let Some(arguments) = &tool_call.arguments {
398                                            call.arguments.push_str(arguments);
399                                        }
400                                    }
401                                    None => {}
402                                }
403                            }
404
405                            *message_body =
406                                RichText::new(body.clone(), &[], &this.language_registry);
407                            cx.notify();
408                        } else {
409                            unreachable!()
410                        }
411                    })?;
412                }
413
414                anyhow::Ok(())
415            }
416            .await;
417
418            let mut tool_tasks = Vec::new();
419            this.update(cx, |this, cx| {
420                if let Some(ChatMessage::Assistant(AssistantMessage {
421                    error: message_error,
422                    tool_calls,
423                    ..
424                })) = this.messages.last_mut()
425                {
426                    if let Err(error) = complete {
427                        message_error.replace(SharedString::from(error.to_string()));
428                        cx.notify();
429                    } else {
430                        for tool_call in tool_calls.iter() {
431                            tool_tasks.push(this.tool_registry.call(tool_call, cx));
432                        }
433                    }
434                }
435            })?;
436
437            if tool_tasks.is_empty() {
438                return Ok(());
439            }
440
441            let tools = join_all(tool_tasks.into_iter()).await;
442            // If the WindowContext went away for any tool's view we don't include it
443            // especially since the below call would fail for the same reason.
444            let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
445
446            this.update(cx, |this, cx| {
447                if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
448                    this.messages.last_mut()
449                {
450                    *tool_calls = tools;
451                    cx.notify();
452                }
453            })?;
454        }
455    }
456
457    fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
458        let message = ChatMessage::Assistant(AssistantMessage {
459            id: self.next_message_id.post_inc(),
460            body: RichText::default(),
461            tool_calls: Vec::new(),
462            error: None,
463        });
464        self.push_message(message, cx);
465    }
466
467    fn push_message(&mut self, message: ChatMessage, cx: &mut ViewContext<Self>) {
468        let old_len = self.messages.len();
469        let focus_handle = Some(message.focus_handle(cx));
470        self.messages.push(message);
471        self.list_state
472            .splice_focusable(old_len..old_len, focus_handle);
473        cx.notify();
474    }
475
476    fn pop_message(&mut self, cx: &mut ViewContext<Self>) {
477        if self.messages.is_empty() {
478            return;
479        }
480
481        self.messages.pop();
482        self.list_state
483            .splice(self.messages.len()..self.messages.len() + 1, 0);
484        cx.notify();
485    }
486
487    fn truncate_messages(&mut self, last_message_id: MessageId, cx: &mut ViewContext<Self>) {
488        if let Some(index) = self.messages.iter().position(|message| match message {
489            ChatMessage::User(message) => message.id == last_message_id,
490            ChatMessage::Assistant(message) => message.id == last_message_id,
491        }) {
492            self.list_state.splice(index + 1..self.messages.len(), 0);
493            self.messages.truncate(index + 1);
494            cx.notify();
495        }
496    }
497
498    fn render_error(
499        &self,
500        error: Option<SharedString>,
501        _ix: usize,
502        cx: &mut ViewContext<Self>,
503    ) -> AnyElement {
504        let theme = cx.theme();
505
506        if let Some(error) = error {
507            div()
508                .py_1()
509                .px_2()
510                .neg_mx_1()
511                .rounded_md()
512                .border()
513                .border_color(theme.status().error_border)
514                // .bg(theme.status().error_background)
515                .text_color(theme.status().error)
516                .child(error.clone())
517                .into_any_element()
518        } else {
519            div().into_any_element()
520        }
521    }
522
523    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
524        let is_last = ix == self.messages.len() - 1;
525
526        match &self.messages[ix] {
527            ChatMessage::User(UserMessage { body, .. }) => div()
528                .when(!is_last, |element| element.mb_2())
529                .child(div().p_2().child(Label::new("You").color(Color::Default)))
530                .child(
531                    div()
532                        .p_2()
533                        .text_color(cx.theme().colors().editor_foreground)
534                        .font(ThemeSettings::get_global(cx).buffer_font.clone())
535                        .bg(cx.theme().colors().editor_background)
536                        .child(body.clone()),
537                )
538                .into_any(),
539            ChatMessage::Assistant(AssistantMessage {
540                id,
541                body,
542                error,
543                tool_calls,
544                ..
545            }) => {
546                let assistant_body = if body.text.is_empty() && !tool_calls.is_empty() {
547                    div()
548                } else {
549                    div().p_2().child(body.element(ElementId::from(id.0), cx))
550                };
551
552                div()
553                    .when(!is_last, |element| element.mb_2())
554                    .child(
555                        div()
556                            .p_2()
557                            .child(Label::new("Assistant").color(Color::Modified)),
558                    )
559                    .child(assistant_body)
560                    .child(self.render_error(error.clone(), ix, cx))
561                    .children(tool_calls.iter().map(|tool_call| {
562                        let result = &tool_call.result;
563                        let name = tool_call.name.clone();
564                        match result {
565                            Some(result) => {
566                                div().p_2().child(result.into_any_element(&name)).into_any()
567                            }
568                            None => div()
569                                .p_2()
570                                .child(Label::new(name).color(Color::Modified))
571                                .child("Running...")
572                                .into_any(),
573                        }
574                    }))
575                    .into_any()
576            }
577        }
578    }
579
580    fn completion_messages(&self, cx: &mut WindowContext) -> Vec<CompletionMessage> {
581        let mut completion_messages = Vec::new();
582
583        for message in &self.messages {
584            match message {
585                ChatMessage::User(UserMessage { body, .. }) => {
586                    // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
587                    // contexts.iter().for_each(|context| {
588                    //     completion_messages.extend(context.completion_messages(cx))
589                    // });
590
591                    // Show user's message last so that the assistant is grounded in the user's request
592                    completion_messages.push(CompletionMessage::User {
593                        content: body.read(cx).text(cx),
594                    });
595                }
596                ChatMessage::Assistant(AssistantMessage {
597                    body, tool_calls, ..
598                }) => {
599                    // In no case do we want to send an empty message. This shouldn't happen, but we might as well
600                    // not break the Chat API if it does.
601                    if body.text.is_empty() && tool_calls.is_empty() {
602                        continue;
603                    }
604
605                    let tool_calls_from_assistant = tool_calls
606                        .iter()
607                        .map(|tool_call| ToolCall {
608                            content: ToolCallContent::Function {
609                                function: FunctionContent {
610                                    name: tool_call.name.clone(),
611                                    arguments: tool_call.arguments.clone(),
612                                },
613                            },
614                            id: tool_call.id.clone(),
615                        })
616                        .collect();
617
618                    completion_messages.push(CompletionMessage::Assistant {
619                        content: Some(body.text.to_string()),
620                        tool_calls: tool_calls_from_assistant,
621                    });
622
623                    for tool_call in tool_calls {
624                        // todo!(): we should not be sending when the tool is still running / has no result
625                        // For now I'm going to have to assume we send an empty string because otherwise
626                        // the Chat API will break -- there is a required message for every tool call by ID
627                        let content = match &tool_call.result {
628                            Some(result) => result.format(&tool_call.name),
629                            None => "".to_string(),
630                        };
631
632                        completion_messages.push(CompletionMessage::Tool {
633                            content,
634                            tool_call_id: tool_call.id.clone(),
635                        });
636                    }
637                }
638            }
639        }
640
641        completion_messages
642    }
643}
644
645impl Render for AssistantChat {
646    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
647        div()
648            .relative()
649            .flex_1()
650            .v_flex()
651            .key_context("AssistantChat")
652            .on_action(cx.listener(Self::submit))
653            .on_action(cx.listener(Self::cancel))
654            .text_color(Color::Default.color(cx))
655            .child(list(self.list_state.clone()).flex_1())
656            .child(Composer::new(
657                cx.view().downgrade(),
658                self.model.clone(),
659                self.composer_editor.clone(),
660                self.user_store.read(cx).current_user(),
661                self.can_submit(),
662                self.tool_registry.clone(),
663            ))
664    }
665}
666
/// Identifier of a message within one chat conversation.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct MessageId(usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one, like a postfix `id++`.
    fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        std::mem::replace(self, next)
    }
}
677
678enum ChatMessage {
679    User(UserMessage),
680    Assistant(AssistantMessage),
681}
682
683impl ChatMessage {
684    fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
685        match self {
686            ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
687            ChatMessage::Assistant(_) => None,
688        }
689    }
690}
691
692struct UserMessage {
693    id: MessageId,
694    body: View<Editor>,
695}
696
697struct AssistantMessage {
698    id: MessageId,
699    body: RichText,
700    tool_calls: Vec<ToolFunctionCall>,
701    error: Option<SharedString>,
702}