1use std::ops::Range;
2use zed::lsp::{Completion, CompletionKind, Symbol, SymbolKind};
3use zed::{CodeLabel, CodeLabelSpan};
4use zed_extension_api::{self as zed, Result};
5
/// Characters that may appear in OCaml prefix/infix operator names.
///
/// The label builders treat any name containing one of these characters as an operator and
/// wrap it in parentheses (e.g. `let ( >>= ) ...`) so the generated snippet is valid OCaml.
/// `|` is included so that operators made only of bars, such as `||`, are recognized too.
const OPERATOR_CHAR: [char; 18] = [
    '~', '!', '?', '%', '<', ':', '.', '$', '&', '*', '+', '-', '/', '=', '>', '@', '^', '|',
];
9
/// The OCaml language extension for Zed. It is a stateless unit struct; all behavior lives in
/// its `zed::Extension` implementation.
struct OcamlExtension;
11
12impl zed::Extension for OcamlExtension {
13 fn new() -> Self {
14 Self
15 }
16
17 fn language_server_command(
18 &mut self,
19 _language_server_id: &zed::LanguageServerId,
20 worktree: &zed::Worktree,
21 ) -> Result<zed::Command> {
22 let path = worktree.which("ocamllsp").ok_or_else(|| {
23 "ocamllsp (ocaml-language-server) must be installed manually.".to_string()
24 })?;
25
26 Ok(zed::Command {
27 command: path,
28 args: Vec::new(),
29 env: worktree.shell_env(),
30 })
31 }
32
33 fn label_for_completion(
34 &self,
35 _language_server_id: &zed::LanguageServerId,
36 completion: Completion,
37 ) -> Option<CodeLabel> {
38 let name = &completion.label;
39 let detail = completion.detail.as_ref().map(|s| s.replace('\n', " "));
40
41 match completion.kind.zip(detail) {
42 Some((CompletionKind::Constructor | CompletionKind::EnumMember, detail)) => {
43 let (argument, return_t) = detail
44 .split_once("->")
45 .map_or((None, detail.as_str()), |(arg, typ)| {
46 (Some(arg.trim()), typ.trim())
47 });
48
49 let type_decl = "type t = ";
50 let type_of = argument.map(|_| " of ").unwrap_or_default();
51 let argument = argument.unwrap_or_default();
52 let terminator = "\n";
53 let let_decl = "let _ ";
54 let let_colon = ": ";
55 let let_suffix = " = ()";
56 let code = format!(
57 "{type_decl}{name}{type_of}{argument}{terminator}{let_decl}{let_colon}{return_t}{let_suffix}"
58 );
59
60 let name_start = type_decl.len();
61 let argument_end = name_start + name.len() + type_of.len() + argument.len();
62 let colon_start = argument_end + terminator.len() + let_decl.len();
63 let return_type_end = code.len() - let_suffix.len();
64 Some(CodeLabel {
65 code,
66 spans: vec![
67 CodeLabelSpan::code_range(name_start..argument_end),
68 CodeLabelSpan::code_range(colon_start..return_type_end),
69 ],
70 filter_range: (0..name.len()).into(),
71 })
72 }
73
74 Some((CompletionKind::Field, detail)) => {
75 let filter_range_start = if name.starts_with(['~', '?']) { 1 } else { 0 };
76
77 let record_prefix = "type t = { ";
78 let record_suffix = "; }";
79 let code = format!("{record_prefix}{name} : {detail}{record_suffix}");
80
81 Some(CodeLabel {
82 spans: vec![CodeLabelSpan::code_range(
83 record_prefix.len()..code.len() - record_suffix.len(),
84 )],
85 code,
86 filter_range: (filter_range_start..name.len()).into(),
87 })
88 }
89
90 Some((CompletionKind::Value, detail)) => {
91 let let_prefix = "let ";
92 let suffix = " = ()";
93 let (l_paren, r_paren) = if name.contains(OPERATOR_CHAR) {
94 ("( ", " )")
95 } else {
96 ("", "")
97 };
98 let code = format!("{let_prefix}{l_paren}{name}{r_paren} : {detail}{suffix}");
99
100 let name_start = let_prefix.len() + l_paren.len();
101 let name_end = name_start + name.len();
102 let type_annotation_start = name_end + r_paren.len();
103 let type_annotation_end = code.len() - suffix.len();
104
105 Some(CodeLabel {
106 spans: vec![
107 CodeLabelSpan::code_range(name_start..name_end),
108 CodeLabelSpan::code_range(type_annotation_start..type_annotation_end),
109 ],
110 filter_range: (0..name.len()).into(),
111 code,
112 })
113 }
114
115 Some((CompletionKind::Method, detail)) => {
116 let method_decl = "class c : object method ";
117 let end = " end";
118 let code = format!("{method_decl}{name} : {detail}{end}");
119
120 Some(CodeLabel {
121 spans: vec![CodeLabelSpan::code_range(
122 method_decl.len()..code.len() - end.len(),
123 )],
124 code,
125 filter_range: (0..name.len()).into(),
126 })
127 }
128
129 Some((kind, _)) => {
130 let highlight_name = match kind {
131 CompletionKind::Module | CompletionKind::Interface => "title",
132 CompletionKind::Keyword => "keyword",
133 CompletionKind::TypeParameter => "type",
134 _ => return None,
135 };
136
137 Some(CodeLabel {
138 spans: vec![(CodeLabelSpan::literal(name, Some(highlight_name.to_string())))],
139 filter_range: (0..name.len()).into(),
140 code: String::new(),
141 })
142 }
143 _ => None,
144 }
145 }
146
147 fn label_for_symbol(
148 &self,
149 _language_server_id: &zed::LanguageServerId,
150 symbol: Symbol,
151 ) -> Option<CodeLabel> {
152 let name = &symbol.name;
153
154 let (code, filter_range, display_range) = match symbol.kind {
155 SymbolKind::Property => {
156 let code = format!("type t = {{ {}: (); }}", name);
157 let filter_range: Range<usize> = 0..name.len();
158 let display_range = 11..11 + name.len();
159 (code, filter_range, display_range)
160 }
161 SymbolKind::Function
162 if name.contains(OPERATOR_CHAR)
163 || (name.starts_with("let") && name.contains(OPERATOR_CHAR)) =>
164 {
165 let code = format!("let ( {name} ) () = ()");
166
167 let filter_range = 6..6 + name.len();
168 let display_range = 0..filter_range.end + 1;
169 (code, filter_range, display_range)
170 }
171 SymbolKind::Function => {
172 let code = format!("let {name} () = ()");
173
174 let filter_range = 4..4 + name.len();
175 let display_range = 0..filter_range.end;
176 (code, filter_range, display_range)
177 }
178 SymbolKind::Constructor => {
179 let code = format!("type t = {name}");
180 let filter_range = 0..name.len();
181 let display_range = 9..9 + name.len();
182 (code, filter_range, display_range)
183 }
184 SymbolKind::Module => {
185 let code = format!("module {name} = struct end");
186 let filter_range = 7..7 + name.len();
187 let display_range = 0..filter_range.end;
188 (code, filter_range, display_range)
189 }
190 SymbolKind::Class => {
191 let code = format!("class {name} = object end");
192 let filter_range = 6..6 + name.len();
193 let display_range = 0..filter_range.end;
194 (code, filter_range, display_range)
195 }
196 SymbolKind::Method => {
197 let code = format!("class c = object method {name} = () end");
198 let filter_range = 0..name.len();
199 let display_range = 17..24 + name.len();
200 (code, filter_range, display_range)
201 }
202 SymbolKind::String => {
203 let code = format!("type {name} = T");
204 let filter_range = 5..5 + name.len();
205 let display_range = 0..filter_range.end;
206 (code, filter_range, display_range)
207 }
208 _ => return None,
209 };
210
211 Some(CodeLabel {
212 code,
213 spans: vec![CodeLabelSpan::code_range(display_range)],
214 filter_range: filter_range.into(),
215 })
216 }
217}
218
// Registers `OcamlExtension` as this crate's Zed extension entry point.
zed::register_extension!(OcamlExtension);