1use std::{ops::Range, sync::Arc};
2
3use gpui::{App, Entity, EntityId, Task, prelude::*};
4
5use edit_prediction::{DataCollectionState, Direction, EditPrediction, EditPredictionProvider};
6use language::{Anchor, ToPoint};
7
8pub struct Zeta2EditPredictionProvider {
9 current: Option<CurrentEditPrediction>,
10 pending: Option<Task<()>>,
11}
12
13impl Zeta2EditPredictionProvider {
14 pub fn new() -> Self {
15 Self {
16 current: None,
17 pending: None,
18 }
19 }
20}
21
22#[derive(Clone)]
23struct CurrentEditPrediction {
24 buffer_id: EntityId,
25 prediction: EditPrediction,
26}
27
28impl EditPredictionProvider for Zeta2EditPredictionProvider {
29 fn name() -> &'static str {
30 // TODO [zeta2]
31 "zed-predict2"
32 }
33
34 fn display_name() -> &'static str {
35 "Zed's Edit Predictions 2"
36 }
37
38 fn show_completions_in_menu() -> bool {
39 true
40 }
41
42 fn show_tab_accept_marker() -> bool {
43 true
44 }
45
46 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
47 // TODO [zeta2]
48 DataCollectionState::Unsupported
49 }
50
51 fn toggle_data_collection(&mut self, _cx: &mut App) {
52 // TODO [zeta2]
53 }
54
55 fn usage(&self, _cx: &App) -> Option<client::EditPredictionUsage> {
56 // TODO [zeta2]
57 None
58 }
59
60 fn is_enabled(
61 &self,
62 _buffer: &Entity<language::Buffer>,
63 _cursor_position: language::Anchor,
64 _cx: &App,
65 ) -> bool {
66 true
67 }
68
69 fn is_refreshing(&self) -> bool {
70 self.pending.is_some()
71 }
72
73 fn refresh(
74 &mut self,
75 _project: Option<Entity<project::Project>>,
76 buffer: Entity<language::Buffer>,
77 cursor_position: language::Anchor,
78 _debounce: bool,
79 cx: &mut Context<Self>,
80 ) {
81 // TODO [zeta2] check account
82 // TODO [zeta2] actually request completion / interpolate
83
84 let snapshot = buffer.read(cx).snapshot();
85 let point = cursor_position.to_point(&snapshot);
86 let end_anchor = snapshot.anchor_before(language::Point::new(
87 point.row,
88 snapshot.line_len(point.row),
89 ));
90
91 let edits: Arc<[(Range<Anchor>, String)]> =
92 vec![(cursor_position..end_anchor, "👻".to_string())].into();
93 let edits_preview_task = buffer.read(cx).preview_edits(edits.clone(), cx);
94
95 // TODO [zeta2] throttle
96 // TODO [zeta2] keep 2 requests
97 self.pending = Some(cx.spawn(async move |this, cx| {
98 let edits_preview = edits_preview_task.await;
99
100 this.update(cx, |this, cx| {
101 this.current = Some(CurrentEditPrediction {
102 buffer_id: buffer.entity_id(),
103 prediction: EditPrediction {
104 // TODO! [zeta2] request id?
105 id: None,
106 edits: edits.to_vec(),
107 edit_preview: Some(edits_preview),
108 },
109 });
110 this.pending.take();
111 cx.notify();
112 })
113 .ok();
114 }));
115 cx.notify();
116 }
117
118 fn cycle(
119 &mut self,
120 _buffer: Entity<language::Buffer>,
121 _cursor_position: language::Anchor,
122 _direction: Direction,
123 _cx: &mut Context<Self>,
124 ) {
125 }
126
127 fn accept(&mut self, _cx: &mut Context<Self>) {
128 // TODO [zeta2] report accept
129 self.current.take();
130 self.pending.take();
131 }
132
133 fn discard(&mut self, _cx: &mut Context<Self>) {
134 self.current.take();
135 self.pending.take();
136 }
137
138 fn suggest(
139 &mut self,
140 buffer: &Entity<language::Buffer>,
141 _cursor_position: language::Anchor,
142 _cx: &mut Context<Self>,
143 ) -> Option<EditPrediction> {
144 let current_prediction = self.current.take()?;
145
146 if current_prediction.buffer_id != buffer.entity_id() {
147 return None;
148 }
149
150 // TODO [zeta2] interpolate
151
152 Some(current_prediction.prediction)
153 }
154}