inline_completion_registry.rs

  1use std::{cell::RefCell, rc::Rc, sync::Arc};
  2
  3use client::{Client, UserStore};
  4use collections::HashMap;
  5use copilot::{Copilot, CopilotCompletionProvider};
  6use editor::{Editor, EditorMode};
  7use feature_flags::{FeatureFlagAppExt, PredictEditsFeatureFlag};
  8use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity, Window};
  9use language::language_settings::{all_language_settings, InlineCompletionProvider};
 10use settings::SettingsStore;
 11use supermaven::{Supermaven, SupermavenCompletionProvider};
 12use zed_predict_tos::ZedPredictTos;
 13
 14pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
 15    let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
 16    cx.observe_new({
 17        let editors = editors.clone();
 18        let client = client.clone();
 19        let user_store = user_store.clone();
 20        move |editor: &mut Editor, window, cx: &mut Context<Editor>| {
 21            if editor.mode() != EditorMode::Full {
 22                return;
 23            }
 24
 25            register_backward_compatible_actions(editor, cx);
 26
 27            let Some(window) = window else {
 28                return;
 29            };
 30
 31            let editor_handle = cx.entity().downgrade();
 32            cx.on_release({
 33                let editor_handle = editor_handle.clone();
 34                let editors = editors.clone();
 35                move |_, _| {
 36                    editors.borrow_mut().remove(&editor_handle);
 37                }
 38            })
 39            .detach();
 40            editors
 41                .borrow_mut()
 42                .insert(editor_handle, window.window_handle());
 43            let provider = all_language_settings(None, cx).inline_completions.provider;
 44            assign_inline_completion_provider(
 45                editor,
 46                provider,
 47                &client,
 48                user_store.clone(),
 49                window,
 50                cx,
 51            );
 52        }
 53    })
 54    .detach();
 55
 56    let mut provider = all_language_settings(None, cx).inline_completions.provider;
 57    for (editor, window) in editors.borrow().iter() {
 58        _ = window.update(cx, |_window, window, cx| {
 59            _ = editor.update(cx, |editor, cx| {
 60                assign_inline_completion_provider(
 61                    editor,
 62                    provider,
 63                    &client,
 64                    user_store.clone(),
 65                    window,
 66                    cx,
 67                );
 68            })
 69        });
 70    }
 71
 72    if cx.has_flag::<PredictEditsFeatureFlag>() {
 73        cx.on_action(clear_zeta_edit_history);
 74    }
 75
 76    cx.observe_flag::<PredictEditsFeatureFlag, _>({
 77        let editors = editors.clone();
 78        let client = client.clone();
 79        let user_store = user_store.clone();
 80        move |active, cx| {
 81            let provider = all_language_settings(None, cx).inline_completions.provider;
 82            assign_inline_completion_providers(&editors, provider, &client, user_store.clone(), cx);
 83            if active && !cx.is_action_available(&zeta::ClearHistory) {
 84                cx.on_action(clear_zeta_edit_history);
 85            }
 86        }
 87    })
 88    .detach();
 89
 90    cx.observe_global::<SettingsStore>({
 91        let editors = editors.clone();
 92        let client = client.clone();
 93        let user_store = user_store.clone();
 94        move |cx| {
 95            let new_provider = all_language_settings(None, cx).inline_completions.provider;
 96            if new_provider != provider {
 97                provider = new_provider;
 98                assign_inline_completion_providers(
 99                    &editors,
100                    provider,
101                    &client,
102                    user_store.clone(),
103                    cx,
104                );
105
106                if !user_store
107                    .read(cx)
108                    .current_user_has_accepted_terms()
109                    .unwrap_or(false)
110                {
111                    match provider {
112                        InlineCompletionProvider::Zed => {
113                            let Some(window) = cx.active_window() else {
114                                return;
115                            };
116
117                            let Some(Some(workspace)) = window
118                                .update(cx, |_, window, _| window.root().flatten())
119                                .ok()
120                            else {
121                                return;
122                            };
123
124                            window
125                                .update(cx, |_, window, cx| {
126                                    ZedPredictTos::toggle(
127                                        workspace,
128                                        user_store.clone(),
129                                        window,
130                                        cx,
131                                    );
132                                })
133                                .ok();
134                        }
135                        InlineCompletionProvider::None
136                        | InlineCompletionProvider::Copilot
137                        | InlineCompletionProvider::Supermaven => {}
138                    }
139                }
140            }
141        }
142    })
143    .detach();
144}
145
146fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
147    if let Some(zeta) = zeta::Zeta::global(cx) {
148        zeta.update(cx, |zeta, _| zeta.clear_history());
149    }
150}
151
152fn assign_inline_completion_providers(
153    editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
154    provider: InlineCompletionProvider,
155    client: &Arc<Client>,
156    user_store: Entity<UserStore>,
157    cx: &mut App,
158) {
159    for (editor, window) in editors.borrow().iter() {
160        _ = window.update(cx, |_window, window, cx| {
161            _ = editor.update(cx, |editor, cx| {
162                assign_inline_completion_provider(
163                    editor,
164                    provider,
165                    &client,
166                    user_store.clone(),
167                    window,
168                    cx,
169                );
170            })
171        });
172    }
173}
174
175fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Editor>) {
176    // We renamed some of these actions to not be copilot-specific, but that
177    // would have not been backwards-compatible. So here we are re-registering
178    // the actions with the old names to not break people's keymaps.
179    editor
180        .register_action(cx.listener(
181            |editor, _: &copilot::Suggest, window: &mut Window, cx: &mut Context<Editor>| {
182                editor.show_inline_completion(&Default::default(), window, cx);
183            },
184        ))
185        .detach();
186    editor
187        .register_action(cx.listener(
188            |editor, _: &copilot::NextSuggestion, window: &mut Window, cx: &mut Context<Editor>| {
189                editor.next_inline_completion(&Default::default(), window, cx);
190            },
191        ))
192        .detach();
193    editor
194        .register_action(cx.listener(
195            |editor,
196             _: &copilot::PreviousSuggestion,
197             window: &mut Window,
198             cx: &mut Context<Editor>| {
199                editor.previous_inline_completion(&Default::default(), window, cx);
200            },
201        ))
202        .detach();
203    editor
204        .register_action(cx.listener(
205            |editor,
206             _: &editor::actions::AcceptPartialCopilotSuggestion,
207             window: &mut Window,
208             cx: &mut Context<Editor>| {
209                editor.accept_partial_inline_completion(&Default::default(), window, cx);
210            },
211        ))
212        .detach();
213}
214
215fn assign_inline_completion_provider(
216    editor: &mut Editor,
217    provider: language::language_settings::InlineCompletionProvider,
218    client: &Arc<Client>,
219    user_store: Entity<UserStore>,
220    window: &mut Window,
221    cx: &mut Context<Editor>,
222) {
223    match provider {
224        language::language_settings::InlineCompletionProvider::None => {}
225        language::language_settings::InlineCompletionProvider::Copilot => {
226            if let Some(copilot) = Copilot::global(cx) {
227                if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
228                    if buffer.read(cx).file().is_some() {
229                        copilot.update(cx, |copilot, cx| {
230                            copilot.register_buffer(&buffer, cx);
231                        });
232                    }
233                }
234                let provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
235                editor.set_inline_completion_provider(Some(provider), window, cx);
236            }
237        }
238        language::language_settings::InlineCompletionProvider::Supermaven => {
239            if let Some(supermaven) = Supermaven::global(cx) {
240                let provider = cx.new(|_| SupermavenCompletionProvider::new(supermaven));
241                editor.set_inline_completion_provider(Some(provider), window, cx);
242            }
243        }
244
245        language::language_settings::InlineCompletionProvider::Zed => {
246            if cx.has_flag::<PredictEditsFeatureFlag>()
247                || (cfg!(debug_assertions) && client.status().borrow().is_connected())
248            {
249                let zeta = zeta::Zeta::register(client.clone(), user_store, cx);
250                if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
251                    if buffer.read(cx).file().is_some() {
252                        zeta.update(cx, |zeta, cx| {
253                            zeta.register_buffer(&buffer, cx);
254                        });
255                    }
256                }
257                let provider = cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta));
258                editor.set_inline_completion_provider(Some(provider), window, cx);
259            }
260        }
261    }
262}