terminal_codegen.rs

  1use crate::inline_prompt_editor::CodegenStatus;
  2use futures::{SinkExt, StreamExt, channel::mpsc};
  3use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
  4use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequest};
  5use std::time::Instant;
  6use terminal::Terminal;
  7use uuid::Uuid;
  8
/// Drives an inline-assistant generation whose output is typed into a terminal.
pub struct TerminalCodegen {
    /// Current state of the generation: `Idle` after construction, `Pending` once
    /// `start` has launched a request, then `Done` or `Error` when it ends.
    pub status: CodegenStatus,
    /// Terminal that receives the generated text.
    terminal: Entity<Terminal>,
    /// Handle to the task spawned by `start`; `stop` replaces it with a ready task.
    generation: Task<()>,
    /// Message id copied from the model response by `start`, if the response had one.
    pub message_id: Option<String>,
    /// Text streamed so far; created by `start` and consumed by `complete` or `undo`.
    transaction: Option<TerminalTransaction>,
    /// Identifier supplied at construction and attached to the telemetry event.
    session_id: Uuid,
}
 17
/// Allows `TerminalCodegen` to emit [`CodegenEvent`]s through its `Context`.
impl EventEmitter<CodegenEvent> for TerminalCodegen {}
 19
 20impl TerminalCodegen {
 21    pub fn new(terminal: Entity<Terminal>, session_id: Uuid) -> Self {
 22        Self {
 23            terminal,
 24            status: CodegenStatus::Idle,
 25            generation: Task::ready(()),
 26            message_id: None,
 27            transaction: None,
 28            session_id,
 29        }
 30    }
 31
    /// Returns the session identifier that was passed to [`TerminalCodegen::new`].
    pub fn session_id(&self) -> Uuid {
        self.session_id
    }
 35
 36    pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
 37        let Some(ConfiguredModel { model, .. }) =
 38            LanguageModelRegistry::read_global(cx).inline_assistant_model()
 39        else {
 40            return;
 41        };
 42
 43        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 44        let session_id = self.session_id;
 45        let model_telemetry_id = model.telemetry_id();
 46        let model_provider_id = model.provider_id().to_string();
 47
 48        self.status = CodegenStatus::Pending;
 49        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
 50        self.generation = cx.spawn(async move |this, cx| {
 51            let prompt = prompt_task.await;
 52            let response = model.stream_completion_text(prompt, cx).await;
 53            let generate = async {
 54                let message_id = response
 55                    .as_ref()
 56                    .ok()
 57                    .and_then(|response| response.message_id.clone());
 58
 59                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 60
 61                let task = cx.background_spawn({
 62                    let message_id = message_id.clone();
 63                    let anthropic_reporter = anthropic_reporter.clone();
 64                    async move {
 65                        let mut response_latency = None;
 66                        let request_start = Instant::now();
 67                        let task = async {
 68                            let mut chunks = response?.stream;
 69                            while let Some(chunk) = chunks.next().await {
 70                                if response_latency.is_none() {
 71                                    response_latency = Some(request_start.elapsed());
 72                                }
 73                                let chunk = chunk?;
 74                                hunks_tx.send(chunk).await?;
 75                            }
 76
 77                            anyhow::Ok(())
 78                        };
 79
 80                        let result = task.await;
 81
 82                        let error_message = result.as_ref().err().map(|error| error.to_string());
 83
 84                        telemetry::event!(
 85                            "Assistant Responded",
 86                            session_id = session_id.to_string(),
 87                            kind = "inline_terminal",
 88                            phase = "response",
 89                            model = model_telemetry_id,
 90                            model_provider = model_provider_id,
 91                            language_name = Option::<&str>::None,
 92                            message_id = message_id,
 93                            response_latency = response_latency,
 94                            error_message = error_message,
 95                        );
 96
 97                        anthropic_reporter.report(language_model::AnthropicEventData {
 98                            completion_type: language_model::AnthropicCompletionType::Terminal,
 99                            event: language_model::AnthropicEventType::Response,
100                            language_name: None,
101                            message_id,
102                        });
103
104                        result?;
105                        anyhow::Ok(())
106                    }
107                });
108
109                this.update(cx, |this, _| {
110                    this.message_id = message_id;
111                })?;
112
113                while let Some(hunk) = hunks_rx.next().await {
114                    this.update(cx, |this, cx| {
115                        if let Some(transaction) = &mut this.transaction {
116                            transaction.push(hunk, cx);
117                            cx.notify();
118                        }
119                    })?;
120                }
121
122                task.await?;
123                anyhow::Ok(())
124            };
125
126            let result = generate.await;
127
128            this.update(cx, |this, cx| {
129                if let Err(error) = result {
130                    this.status = CodegenStatus::Error(error);
131                } else {
132                    this.status = CodegenStatus::Done;
133                }
134                cx.emit(CodegenEvent::Finished);
135                cx.notify();
136            })
137            .ok();
138        });
139        cx.notify();
140    }
141
142    pub fn completion(&self) -> Option<String> {
143        self.transaction
144            .as_ref()
145            .map(|transaction| transaction.completion.clone())
146    }
147
148    pub fn stop(&mut self, cx: &mut Context<Self>) {
149        self.status = CodegenStatus::Done;
150        self.generation = Task::ready(());
151        cx.emit(CodegenEvent::Finished);
152        cx.notify();
153    }
154
155    pub fn complete(&mut self, cx: &mut Context<Self>) {
156        if let Some(transaction) = self.transaction.take() {
157            transaction.complete(cx);
158        }
159    }
160
161    pub fn undo(&mut self, cx: &mut Context<Self>) {
162        if let Some(transaction) = self.transaction.take() {
163            transaction.undo(cx);
164        }
165    }
166}
167
/// Events emitted by `TerminalCodegen`.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Emitted when a generation ends, either because the spawned task ran to its end
    /// (successfully or with an error) or because `stop` was called.
    Finished,
}
172
/// Byte sequence sent to the terminal to discard the text typed so far.
///
/// This is Ctrl-C (`0x03`) on Windows and Ctrl-U (`0x15`) everywhere else.
pub const CLEAR_INPUT: &str = if cfg!(target_os = "windows") {
    "\x03"
} else {
    "\x15"
};

/// Carriage return (`0x0d`), sent to the terminal to submit the typed line.
const CARRIAGE_RETURN: &str = "\x0d";
178
/// Text typed into a terminal on behalf of the assistant that has not yet been
/// submitted or discarded.
struct TerminalTransaction {
    /// Everything pushed so far, after sanitization.
    completion: String,
    /// Terminal the text is being typed into.
    terminal: Entity<Terminal>,
}
183
184impl TerminalTransaction {
185    pub fn start(terminal: Entity<Terminal>) -> Self {
186        Self {
187            completion: String::new(),
188            terminal,
189        }
190    }
191
192    pub fn push(&mut self, hunk: String, cx: &mut App) {
193        // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
194        let input = Self::sanitize_input(hunk);
195        self.completion.push_str(&input);
196        self.terminal
197            .update(cx, |terminal, _| terminal.input(input.into_bytes()));
198    }
199
200    pub fn undo(self, cx: &mut App) {
201        self.terminal
202            .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
203    }
204
205    pub fn complete(self, cx: &mut App) {
206        self.terminal
207            .update(cx, |terminal, _| terminal.input(CARRIAGE_RETURN.as_bytes()));
208    }
209
210    fn sanitize_input(mut input: String) -> String {
211        input.retain(|c| c != '\r' && c != '\n');
212        input
213    }
214}