terminal_codegen.rs

  1use crate::inline_prompt_editor::CodegenStatus;
  2use client::telemetry::Telemetry;
  3use futures::{SinkExt, StreamExt, channel::mpsc};
  4use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
  5use language_model::{LanguageModelRegistry, LanguageModelRequest, report_assistant_event};
  6use std::{sync::Arc, time::Instant};
  7use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
  8use terminal::Terminal;
  9
 10pub struct TerminalCodegen {
 11    pub status: CodegenStatus,
 12    pub telemetry: Option<Arc<Telemetry>>,
 13    terminal: Entity<Terminal>,
 14    generation: Task<()>,
 15    pub message_id: Option<String>,
 16    transaction: Option<TerminalTransaction>,
 17}
 18
 19impl EventEmitter<CodegenEvent> for TerminalCodegen {}
 20
 21impl TerminalCodegen {
 22    pub fn new(terminal: Entity<Terminal>, telemetry: Option<Arc<Telemetry>>) -> Self {
 23        Self {
 24            terminal,
 25            telemetry,
 26            status: CodegenStatus::Idle,
 27            generation: Task::ready(()),
 28            message_id: None,
 29            transaction: None,
 30        }
 31    }
 32
 33    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
 34        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
 35            return;
 36        };
 37
 38        let model_api_key = model.api_key(cx);
 39        let http_client = cx.http_client();
 40        let telemetry = self.telemetry.clone();
 41        self.status = CodegenStatus::Pending;
 42        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
 43        self.generation = cx.spawn(async move |this, cx| {
 44            let model_telemetry_id = model.telemetry_id();
 45            let model_provider_id = model.provider_id();
 46            let response = model.stream_completion_text(prompt, &cx).await;
 47            let generate = async {
 48                let message_id = response
 49                    .as_ref()
 50                    .ok()
 51                    .and_then(|response| response.message_id.clone());
 52
 53                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 54
 55                let task = cx.background_spawn({
 56                    let message_id = message_id.clone();
 57                    let executor = cx.background_executor().clone();
 58                    async move {
 59                        let mut response_latency = None;
 60                        let request_start = Instant::now();
 61                        let task = async {
 62                            let mut chunks = response?.stream;
 63                            while let Some(chunk) = chunks.next().await {
 64                                if response_latency.is_none() {
 65                                    response_latency = Some(request_start.elapsed());
 66                                }
 67                                let chunk = chunk?;
 68                                hunks_tx.send(chunk).await?;
 69                            }
 70
 71                            anyhow::Ok(())
 72                        };
 73
 74                        let result = task.await;
 75
 76                        let error_message = result.as_ref().err().map(|error| error.to_string());
 77                        report_assistant_event(
 78                            AssistantEvent {
 79                                conversation_id: None,
 80                                kind: AssistantKind::InlineTerminal,
 81                                message_id,
 82                                phase: AssistantPhase::Response,
 83                                model: model_telemetry_id,
 84                                model_provider: model_provider_id.to_string(),
 85                                response_latency,
 86                                error_message,
 87                                language_name: None,
 88                            },
 89                            telemetry,
 90                            http_client,
 91                            model_api_key,
 92                            &executor,
 93                        );
 94
 95                        result?;
 96                        anyhow::Ok(())
 97                    }
 98                });
 99
100                this.update(cx, |this, _| {
101                    this.message_id = message_id;
102                })?;
103
104                while let Some(hunk) = hunks_rx.next().await {
105                    this.update(cx, |this, cx| {
106                        if let Some(transaction) = &mut this.transaction {
107                            transaction.push(hunk, cx);
108                            cx.notify();
109                        }
110                    })?;
111                }
112
113                task.await?;
114                anyhow::Ok(())
115            };
116
117            let result = generate.await;
118
119            this.update(cx, |this, cx| {
120                if let Err(error) = result {
121                    this.status = CodegenStatus::Error(error);
122                } else {
123                    this.status = CodegenStatus::Done;
124                }
125                cx.emit(CodegenEvent::Finished);
126                cx.notify();
127            })
128            .ok();
129        });
130        cx.notify();
131    }
132
133    pub fn stop(&mut self, cx: &mut Context<Self>) {
134        self.status = CodegenStatus::Done;
135        self.generation = Task::ready(());
136        cx.emit(CodegenEvent::Finished);
137        cx.notify();
138    }
139
140    pub fn complete(&mut self, cx: &mut Context<Self>) {
141        if let Some(transaction) = self.transaction.take() {
142            transaction.complete(cx);
143        }
144    }
145
146    pub fn undo(&mut self, cx: &mut Context<Self>) {
147        if let Some(transaction) = self.transaction.take() {
148            transaction.undo(cx);
149        }
150    }
151}
152
/// Events emitted by [`TerminalCodegen`].
#[derive(Debug, Clone, Copy)]
pub enum CodegenEvent {
    /// The generation ended: it completed, failed, or was stopped.
    Finished,
}
157
/// Control character written to the terminal to clear the pending input line
/// (Ctrl-C on Windows).
#[cfg(target_os = "windows")]
pub const CLEAR_INPUT: &str = "\x03";
/// Control character written to the terminal to clear the pending input line
/// (Ctrl-U elsewhere).
#[cfg(not(target_os = "windows"))]
pub const CLEAR_INPUT: &str = "\x15";
/// Carriage return, written to the terminal to submit the pending input line.
const CARRIAGE_RETURN: &str = "\x0d";
163
164struct TerminalTransaction {
165    terminal: Entity<Terminal>,
166}
167
168impl TerminalTransaction {
169    pub fn start(terminal: Entity<Terminal>) -> Self {
170        Self { terminal }
171    }
172
173    pub fn push(&mut self, hunk: String, cx: &mut App) {
174        // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
175        let input = Self::sanitize_input(hunk);
176        self.terminal
177            .update(cx, |terminal, _| terminal.input(input));
178    }
179
180    pub fn undo(&self, cx: &mut App) {
181        self.terminal
182            .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
183    }
184
185    pub fn complete(&self, cx: &mut App) {
186        self.terminal.update(cx, |terminal, _| {
187            terminal.input(CARRIAGE_RETURN.to_string())
188        });
189    }
190
191    fn sanitize_input(input: String) -> String {
192        input.replace(['\r', '\n'], "")
193    }
194}