terminal_inline_assistant.rs

  1use crate::{
  2    context::load_context,
  3    inline_prompt_editor::{
  4        CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5    },
  6    terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen},
  7};
  8use agent::HistoryStore;
  9use agent_settings::AgentSettings;
 10use anyhow::{Context as _, Result};
 11
 12use cloud_llm_client::CompletionIntent;
 13use collections::{HashMap, VecDeque};
 14use editor::{MultiBuffer, actions::SelectAll};
 15use fs::Fs;
 16use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
 17use language::Buffer;
 18use language_model::{
 19    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 20    Role, report_anthropic_event,
 21};
 22use project::Project;
 23use prompt_store::{PromptBuilder, PromptStore};
 24use std::sync::Arc;
 25use terminal_view::TerminalView;
 26use ui::prelude::*;
 27use util::ResultExt;
 28use uuid::Uuid;
 29use workspace::{Toast, Workspace, notifications::NotificationId};
 30
 31pub fn init(fs: Arc<dyn Fs>, prompt_builder: Arc<PromptBuilder>, cx: &mut App) {
 32    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder));
 33}
 34
/// Number of most recent non-empty terminal lines included as context in the prompt
/// sent to the model.
const DEFAULT_CONTEXT_LINES: usize = 50;
/// Maximum number of (deduplicated) prompts retained in the inline assist prompt history;
/// the oldest entries are dropped first.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
 37
/// Global coordinator for inline assists opened inside terminal views.
///
/// Owns every live assist and a shared history of submitted prompts.
pub struct TerminalInlineAssistant {
    /// Id for the next assist; advanced with `post_inc` each time an assist is opened.
    next_assist_id: TerminalInlineAssistId,
    /// Live assists keyed by id; an entry is removed when its assist is finished.
    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
    /// Previously submitted prompts, oldest first, deduplicated and capped at
    /// `PROMPT_HISTORY_MAX_LEN`. A copy is handed to each new prompt editor.
    prompt_history: VecDeque<String>,
    /// Filesystem handle passed through to prompt editors.
    fs: Arc<dyn Fs>,
    /// Generates the terminal-assistant prompt text sent to the model.
    prompt_builder: Arc<PromptBuilder>,
}
 45
// Stored as a gpui global (see `init`) so event subscriptions can reach it via `update_global`.
impl Global for TerminalInlineAssistant {}
 47
 48impl TerminalInlineAssistant {
 49    pub fn new(fs: Arc<dyn Fs>, prompt_builder: Arc<PromptBuilder>) -> Self {
 50        Self {
 51            next_assist_id: TerminalInlineAssistId::default(),
 52            assists: HashMap::default(),
 53            prompt_history: VecDeque::default(),
 54            fs,
 55            prompt_builder,
 56        }
 57    }
 58
 59    pub fn assist(
 60        &mut self,
 61        terminal_view: &Entity<TerminalView>,
 62        workspace: WeakEntity<Workspace>,
 63        project: WeakEntity<Project>,
 64        thread_store: Entity<HistoryStore>,
 65        prompt_store: Option<Entity<PromptStore>>,
 66        initial_prompt: Option<String>,
 67        window: &mut Window,
 68        cx: &mut App,
 69    ) {
 70        let terminal = terminal_view.read(cx).terminal().clone();
 71        let assist_id = self.next_assist_id.post_inc();
 72        let session_id = Uuid::new_v4();
 73        let prompt_buffer = cx.new(|cx| {
 74            MultiBuffer::singleton(
 75                cx.new(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)),
 76                cx,
 77            )
 78        });
 79        let codegen = cx.new(|_| TerminalCodegen::new(terminal, session_id));
 80
 81        let prompt_editor = cx.new(|cx| {
 82            PromptEditor::new_terminal(
 83                assist_id,
 84                self.prompt_history.clone(),
 85                prompt_buffer.clone(),
 86                codegen,
 87                session_id,
 88                self.fs.clone(),
 89                thread_store.clone(),
 90                prompt_store.clone(),
 91                project.clone(),
 92                workspace.clone(),
 93                window,
 94                cx,
 95            )
 96        });
 97        let prompt_editor_render = prompt_editor.clone();
 98        let block = terminal_view::BlockProperties {
 99            height: 4,
100            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
101        };
102        terminal_view.update(cx, |terminal_view, cx| {
103            terminal_view.set_block_below_cursor(block, window, cx);
104        });
105
106        let terminal_assistant = TerminalInlineAssist::new(
107            assist_id,
108            terminal_view,
109            prompt_editor,
110            workspace.clone(),
111            window,
112            cx,
113        );
114
115        self.assists.insert(assist_id, terminal_assistant);
116
117        self.focus_assist(assist_id, window, cx);
118    }
119
120    fn focus_assist(
121        &mut self,
122        assist_id: TerminalInlineAssistId,
123        window: &mut Window,
124        cx: &mut App,
125    ) {
126        let assist = &self.assists[&assist_id];
127        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
128            prompt_editor.update(cx, |this, cx| {
129                this.editor.update(cx, |editor, cx| {
130                    window.focus(&editor.focus_handle(cx));
131                    editor.select_all(&SelectAll, window, cx);
132                });
133            });
134        }
135    }
136
137    fn handle_prompt_editor_event(
138        &mut self,
139        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
140        event: &PromptEditorEvent,
141        window: &mut Window,
142        cx: &mut App,
143    ) {
144        let assist_id = prompt_editor.read(cx).id();
145        match event {
146            PromptEditorEvent::StartRequested => {
147                self.start_assist(assist_id, cx);
148            }
149            PromptEditorEvent::StopRequested => {
150                self.stop_assist(assist_id, cx);
151            }
152            PromptEditorEvent::ConfirmRequested { execute } => {
153                self.finish_assist(assist_id, false, *execute, window, cx);
154            }
155            PromptEditorEvent::CancelRequested => {
156                self.finish_assist(assist_id, true, false, window, cx);
157            }
158            PromptEditorEvent::Resized { height_in_lines } => {
159                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
160            }
161        }
162    }
163
164    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
165        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
166            assist
167        } else {
168            return;
169        };
170
171        let Some(user_prompt) = assist
172            .prompt_editor
173            .as_ref()
174            .map(|editor| editor.read(cx).prompt(cx))
175        else {
176            return;
177        };
178
179        self.prompt_history.retain(|prompt| *prompt != user_prompt);
180        self.prompt_history.push_back(user_prompt);
181        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
182            self.prompt_history.pop_front();
183        }
184
185        assist
186            .terminal
187            .update(cx, |terminal, cx| {
188                terminal
189                    .terminal()
190                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
191            })
192            .log_err();
193
194        let codegen = assist.codegen.clone();
195        let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
196            return;
197        };
198
199        codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
200    }
201
202    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
203        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
204            assist
205        } else {
206            return;
207        };
208
209        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
210    }
211
212    fn request_for_inline_assist(
213        &self,
214        assist_id: TerminalInlineAssistId,
215        cx: &mut App,
216    ) -> Result<Task<LanguageModelRequest>> {
217        let ConfiguredModel { model, .. } = LanguageModelRegistry::read_global(cx)
218            .inline_assistant_model()
219            .context("No inline assistant model")?;
220
221        let assist = self.assists.get(&assist_id).context("invalid assist")?;
222
223        let shell = std::env::var("SHELL").ok();
224        let (latest_output, working_directory) = assist
225            .terminal
226            .update(cx, |terminal, cx| {
227                let terminal = terminal.entity().read(cx);
228                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
229                let working_directory = terminal
230                    .working_directory()
231                    .map(|path| path.to_string_lossy().into_owned());
232                (latest_output, working_directory)
233            })
234            .ok()
235            .unwrap_or_default();
236
237        let prompt_editor = assist.prompt_editor.clone().context("invalid assist")?;
238
239        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
240            &prompt_editor.read(cx).prompt(cx),
241            shell.as_deref(),
242            working_directory.as_deref(),
243            &latest_output,
244        )?;
245
246        let temperature = AgentSettings::temperature_for_model(&model, cx);
247
248        let mention_set = prompt_editor.read(cx).mention_set().clone();
249        let load_context_task = load_context(&mention_set, cx);
250
251        Ok(cx.background_spawn(async move {
252            let mut request_message = LanguageModelRequestMessage {
253                role: Role::User,
254                content: vec![],
255                cache: false,
256                reasoning_details: None,
257            };
258
259            if let Some(context) = load_context_task.await {
260                context.add_to_request_message(&mut request_message);
261            }
262
263            request_message.content.push(prompt.into());
264
265            LanguageModelRequest {
266                thread_id: None,
267                prompt_id: None,
268                mode: None,
269                intent: Some(CompletionIntent::TerminalInlineAssist),
270                messages: vec![request_message],
271                tools: Vec::new(),
272                tool_choice: None,
273                stop: Vec::new(),
274                temperature,
275                thinking_allowed: false,
276            }
277        }))
278    }
279
280    fn finish_assist(
281        &mut self,
282        assist_id: TerminalInlineAssistId,
283        undo: bool,
284        execute: bool,
285        window: &mut Window,
286        cx: &mut App,
287    ) {
288        self.dismiss_assist(assist_id, window, cx);
289
290        if let Some(assist) = self.assists.remove(&assist_id) {
291            assist
292                .terminal
293                .update(cx, |this, cx| {
294                    this.clear_block_below_cursor(cx);
295                    this.focus_handle(cx).focus(window);
296                })
297                .log_err();
298
299            if let Some(ConfiguredModel { model, .. }) =
300                LanguageModelRegistry::read_global(cx).inline_assistant_model()
301            {
302                let codegen = assist.codegen.read(cx);
303                let session_id = codegen.session_id();
304                let message_id = codegen.message_id.clone();
305                let model_telemetry_id = model.telemetry_id();
306                let model_provider_id = model.provider_id().to_string();
307
308                let (phase, event_type, anthropic_event_type) = if undo {
309                    (
310                        "rejected",
311                        "Assistant Response Rejected",
312                        language_model::AnthropicEventType::Reject,
313                    )
314                } else {
315                    (
316                        "accepted",
317                        "Assistant Response Accepted",
318                        language_model::AnthropicEventType::Accept,
319                    )
320                };
321
322                // Fire Zed telemetry
323                telemetry::event!(
324                    event_type,
325                    kind = "inline_terminal",
326                    phase = phase,
327                    model = model_telemetry_id,
328                    model_provider = model_provider_id,
329                    message_id = message_id,
330                    session_id = session_id,
331                );
332
333                report_anthropic_event(
334                    &model,
335                    language_model::AnthropicEventData {
336                        completion_type: language_model::AnthropicCompletionType::Terminal,
337                        event: anthropic_event_type,
338                        language_name: None,
339                        message_id,
340                    },
341                    cx,
342                );
343            }
344
345            assist.codegen.update(cx, |codegen, cx| {
346                if undo {
347                    codegen.undo(cx);
348                } else if execute {
349                    codegen.complete(cx);
350                }
351            });
352        }
353    }
354
355    fn dismiss_assist(
356        &mut self,
357        assist_id: TerminalInlineAssistId,
358        window: &mut Window,
359        cx: &mut App,
360    ) -> bool {
361        let Some(assist) = self.assists.get_mut(&assist_id) else {
362            return false;
363        };
364        if assist.prompt_editor.is_none() {
365            return false;
366        }
367        assist.prompt_editor = None;
368        assist
369            .terminal
370            .update(cx, |this, cx| {
371                this.clear_block_below_cursor(cx);
372                this.focus_handle(cx).focus(window);
373            })
374            .is_ok()
375    }
376
377    fn insert_prompt_editor_into_terminal(
378        &mut self,
379        assist_id: TerminalInlineAssistId,
380        height: u8,
381        window: &mut Window,
382        cx: &mut App,
383    ) {
384        if let Some(assist) = self.assists.get_mut(&assist_id)
385            && let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned()
386        {
387            assist
388                .terminal
389                .update(cx, |terminal, cx| {
390                    terminal.clear_block_below_cursor(cx);
391                    let block = terminal_view::BlockProperties {
392                        height,
393                        render: Box::new(move |_| prompt_editor.clone().into_any_element()),
394                    };
395                    terminal.set_block_below_cursor(block, window, cx);
396                })
397                .log_err();
398        }
399    }
400}
401
/// State of a single inline assist opened in a terminal view.
struct TerminalInlineAssist {
    /// The terminal view hosting the assist; held weakly, so updates through it may fail
    /// (callers log or ignore the error).
    terminal: WeakEntity<TerminalView>,
    /// The prompt editor shown below the cursor; `None` once the assist has been dismissed.
    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
    /// Drives generation for this assist (started, stopped, undone or completed by
    /// `TerminalInlineAssistant`).
    codegen: Entity<TerminalCodegen>,
    /// Workspace used to surface error toasts; held weakly.
    workspace: WeakEntity<Workspace>,
    /// Keeps the prompt-editor and codegen event subscriptions alive for this assist.
    _subscriptions: Vec<Subscription>,
}
409
410impl TerminalInlineAssist {
411    pub fn new(
412        assist_id: TerminalInlineAssistId,
413        terminal: &Entity<TerminalView>,
414        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
415        workspace: WeakEntity<Workspace>,
416        window: &mut Window,
417        cx: &mut App,
418    ) -> Self {
419        let codegen = prompt_editor.read(cx).codegen().clone();
420        Self {
421            terminal: terminal.downgrade(),
422            prompt_editor: Some(prompt_editor.clone()),
423            codegen: codegen.clone(),
424            workspace,
425            _subscriptions: vec![
426                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
427                    TerminalInlineAssistant::update_global(cx, |this, cx| {
428                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
429                    })
430                }),
431                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
432                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
433                        CodegenEvent::Finished => {
434                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
435                                assist
436                            } else {
437                                return;
438                            };
439
440                            if let CodegenStatus::Error(error) = &codegen.read(cx).status
441                                && assist.prompt_editor.is_none()
442                                && let Some(workspace) = assist.workspace.upgrade()
443                            {
444                                let error = format!("Terminal inline assistant error: {}", error);
445                                workspace.update(cx, |workspace, cx| {
446                                    struct InlineAssistantError;
447
448                                    let id = NotificationId::composite::<InlineAssistantError>(
449                                        assist_id.0,
450                                    );
451
452                                    workspace.show_toast(Toast::new(id, error), cx);
453                                })
454                            }
455
456                            if assist.prompt_editor.is_none() {
457                                this.finish_assist(assist_id, false, false, window, cx);
458                            }
459                        }
460                    })
461                }),
462            ],
463        }
464    }
465}