1use crate::inline_prompt_editor::CodegenStatus;
  2use client::telemetry::Telemetry;
  3use futures::{SinkExt, StreamExt, channel::mpsc};
  4use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
  5use language_model::{
  6    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event,
  7};
  8use std::{sync::Arc, time::Instant};
  9use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 10use terminal::Terminal;
 11
 12pub struct TerminalCodegen {
 13    pub status: CodegenStatus,
 14    pub telemetry: Option<Arc<Telemetry>>,
 15    terminal: Entity<Terminal>,
 16    generation: Task<()>,
 17    pub message_id: Option<String>,
 18    transaction: Option<TerminalTransaction>,
 19}
 20
 21impl EventEmitter<CodegenEvent> for TerminalCodegen {}
 22
 23impl TerminalCodegen {
 24    pub fn new(terminal: Entity<Terminal>, telemetry: Option<Arc<Telemetry>>) -> Self {
 25        Self {
 26            terminal,
 27            telemetry,
 28            status: CodegenStatus::Idle,
 29            generation: Task::ready(()),
 30            message_id: None,
 31            transaction: None,
 32        }
 33    }
 34
 35    pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
 36        let Some(ConfiguredModel { model, .. }) =
 37            LanguageModelRegistry::read_global(cx).inline_assistant_model()
 38        else {
 39            return;
 40        };
 41
 42        let model_api_key = model.api_key(cx);
 43        let http_client = cx.http_client();
 44        let telemetry = self.telemetry.clone();
 45        self.status = CodegenStatus::Pending;
 46        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
 47        self.generation = cx.spawn(async move |this, cx| {
 48            let prompt = prompt_task.await;
 49            let model_telemetry_id = model.telemetry_id();
 50            let model_provider_id = model.provider_id();
 51            let response = model.stream_completion_text(prompt, cx).await;
 52            let generate = async {
 53                let message_id = response
 54                    .as_ref()
 55                    .ok()
 56                    .and_then(|response| response.message_id.clone());
 57
 58                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 59
 60                let task = cx.background_spawn({
 61                    let message_id = message_id.clone();
 62                    let executor = cx.background_executor().clone();
 63                    async move {
 64                        let mut response_latency = None;
 65                        let request_start = Instant::now();
 66                        let task = async {
 67                            let mut chunks = response?.stream;
 68                            while let Some(chunk) = chunks.next().await {
 69                                if response_latency.is_none() {
 70                                    response_latency = Some(request_start.elapsed());
 71                                }
 72                                let chunk = chunk?;
 73                                hunks_tx.send(chunk).await?;
 74                            }
 75
 76                            anyhow::Ok(())
 77                        };
 78
 79                        let result = task.await;
 80
 81                        let error_message = result.as_ref().err().map(|error| error.to_string());
 82                        report_assistant_event(
 83                            AssistantEventData {
 84                                conversation_id: None,
 85                                kind: AssistantKind::InlineTerminal,
 86                                message_id,
 87                                phase: AssistantPhase::Response,
 88                                model: model_telemetry_id,
 89                                model_provider: model_provider_id.to_string(),
 90                                response_latency,
 91                                error_message,
 92                                language_name: None,
 93                            },
 94                            telemetry,
 95                            http_client,
 96                            model_api_key,
 97                            &executor,
 98                        );
 99
100                        result?;
101                        anyhow::Ok(())
102                    }
103                });
104
105                this.update(cx, |this, _| {
106                    this.message_id = message_id;
107                })?;
108
109                while let Some(hunk) = hunks_rx.next().await {
110                    this.update(cx, |this, cx| {
111                        if let Some(transaction) = &mut this.transaction {
112                            transaction.push(hunk, cx);
113                            cx.notify();
114                        }
115                    })?;
116                }
117
118                task.await?;
119                anyhow::Ok(())
120            };
121
122            let result = generate.await;
123
124            this.update(cx, |this, cx| {
125                if let Err(error) = result {
126                    this.status = CodegenStatus::Error(error);
127                } else {
128                    this.status = CodegenStatus::Done;
129                }
130                cx.emit(CodegenEvent::Finished);
131                cx.notify();
132            })
133            .ok();
134        });
135        cx.notify();
136    }
137
138    pub fn stop(&mut self, cx: &mut Context<Self>) {
139        self.status = CodegenStatus::Done;
140        self.generation = Task::ready(());
141        cx.emit(CodegenEvent::Finished);
142        cx.notify();
143    }
144
145    pub fn complete(&mut self, cx: &mut Context<Self>) {
146        if let Some(transaction) = self.transaction.take() {
147            transaction.complete(cx);
148        }
149    }
150
151    pub fn undo(&mut self, cx: &mut Context<Self>) {
152        if let Some(transaction) = self.transaction.take() {
153            transaction.undo(cx);
154        }
155    }
156}
157
/// Events emitted by a terminal codegen.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// A generation has ended: it ran to completion, failed, or was stopped.
    Finished,
}
162
/// Input written to the terminal when a transaction is undone.
/// On Windows this is ETX (`0x03`, Ctrl-C).
#[cfg(target_os = "windows")]
pub const CLEAR_INPUT: &str = "\x03";
/// Input written to the terminal when a transaction is undone.
/// On every non-Windows target this is NAK (`0x15`, Ctrl-U).
#[cfg(not(target_os = "windows"))]
pub const CLEAR_INPUT: &str = "\x15";
/// Input written to the terminal when a transaction is completed: a single CR (`0x0d`).
const CARRIAGE_RETURN: &str = "\x0d";
168
169struct TerminalTransaction {
170    terminal: Entity<Terminal>,
171}
172
173impl TerminalTransaction {
174    pub fn start(terminal: Entity<Terminal>) -> Self {
175        Self { terminal }
176    }
177
178    pub fn push(&mut self, hunk: String, cx: &mut App) {
179        // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
180        let input = Self::sanitize_input(hunk);
181        self.terminal
182            .update(cx, |terminal, _| terminal.input(input.into_bytes()));
183    }
184
185    pub fn undo(&self, cx: &mut App) {
186        self.terminal
187            .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
188    }
189
190    pub fn complete(&self, cx: &mut App) {
191        self.terminal
192            .update(cx, |terminal, _| terminal.input(CARRIAGE_RETURN.as_bytes()));
193    }
194
195    fn sanitize_input(mut input: String) -> String {
196        input.retain(|c| c != '\r' && c != '\n');
197        input
198    }
199}