1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context_store;
5mod inline_assistant;
6mod model_selector;
7mod prompt_library;
8mod prompts;
9mod search;
10mod slash_command;
11mod streaming_diff;
12
13pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
14use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
15use assistant_slash_command::SlashCommandRegistry;
16use client::{proto, Client};
17use command_palette_hooks::CommandPaletteFilter;
18pub(crate) use completion_provider::*;
19pub(crate) use context_store::*;
20use fs::Fs;
21use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
22pub(crate) use inline_assistant::*;
23pub(crate) use model_selector::*;
24use rustdoc::RustdocStore;
25use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
26use serde::{Deserialize, Serialize};
27use settings::{Settings, SettingsStore};
28use slash_command::{
29 active_command, default_command, diagnostics_command, fetch_command, file_command, now_command,
30 project_command, prompt_command, rustdoc_command, search_command, tabs_command, term_command,
31};
32use std::{
33 fmt::{self, Display},
34 sync::Arc,
35};
36pub(crate) use streaming_diff::*;
37
// Declares the `assistant::*` GPUI actions that can be bound to keystrokes
// or invoked from the command palette (the namespace shown/hidden by
// `Assistant::set_enabled` below).
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
55
/// Identifier for a message within an assistant conversation.
/// Ord/Hash/serde derives suggest it is used both as a map key and in
/// persisted conversation data.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
60
/// The author of a message in a model conversation.
///
/// Serialized in lowercase ("user", "assistant", "system"), matching the
/// `Display` implementation below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
68
69impl Role {
70 pub fn cycle(&mut self) {
71 *self = match self {
72 Role::User => Role::Assistant,
73 Role::Assistant => Role::System,
74 Role::System => Role::User,
75 }
76 }
77}
78
79impl Display for Role {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
81 match self {
82 Role::User => write!(f, "user"),
83 Role::Assistant => write!(f, "assistant"),
84 Role::System => write!(f, "system"),
85 }
86 }
87}
88
/// A language model the assistant can send requests to, tagged by provider.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// Model served through the zed.dev cloud service.
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
    Ollama(OllamaModel),
}
96
impl Default for LanguageModel {
    /// Defaults to the cloud provider's default model.
    fn default() -> Self {
        LanguageModel::Cloud(CloudModel::default())
    }
}
102
103impl LanguageModel {
104 pub fn telemetry_id(&self) -> String {
105 match self {
106 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
107 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
108 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
109 LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
110 }
111 }
112
113 pub fn display_name(&self) -> String {
114 match self {
115 LanguageModel::OpenAi(model) => model.display_name().into(),
116 LanguageModel::Anthropic(model) => model.display_name().into(),
117 LanguageModel::Cloud(model) => model.display_name().into(),
118 LanguageModel::Ollama(model) => model.display_name().into(),
119 }
120 }
121
122 pub fn max_token_count(&self) -> usize {
123 match self {
124 LanguageModel::OpenAi(model) => model.max_token_count(),
125 LanguageModel::Anthropic(model) => model.max_token_count(),
126 LanguageModel::Cloud(model) => model.max_token_count(),
127 LanguageModel::Ollama(model) => model.max_token_count(),
128 }
129 }
130
131 pub fn id(&self) -> &str {
132 match self {
133 LanguageModel::OpenAi(model) => model.id(),
134 LanguageModel::Anthropic(model) => model.id(),
135 LanguageModel::Cloud(model) => model.id(),
136 LanguageModel::Ollama(model) => model.id(),
137 }
138 }
139}
140
/// A single message in a request sent to a language model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}
146
147impl LanguageModelRequestMessage {
148 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
149 proto::LanguageModelRequestMessage {
150 role: match self.role {
151 Role::User => proto::LanguageModelRole::LanguageModelUser,
152 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
153 Role::System => proto::LanguageModelRole::LanguageModelSystem,
154 } as i32,
155 content: self.content.clone(),
156 tool_calls: Vec::new(),
157 tool_call_id: None,
158 }
159 }
160}
161
/// A completion request for a language model, convertible to the RPC form
/// via `to_proto`.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    // Stop sequences forwarded to the model — presumably generation halts
    // when one is emitted; exact semantics are provider-defined.
    pub stop: Vec<String>,
    pub temperature: f32,
}
169
170impl LanguageModelRequest {
171 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
172 proto::CompleteWithLanguageModel {
173 model: self.model.id().to_string(),
174 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
175 stop: self.stop.clone(),
176 temperature: self.temperature,
177 tool_choice: None,
178 tools: Vec::new(),
179 }
180 }
181
182 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
183 pub fn preprocess(&mut self) {
184 match &self.model {
185 LanguageModel::OpenAi(_) => {}
186 LanguageModel::Anthropic(_) => {}
187 LanguageModel::Ollama(_) => {}
188 LanguageModel::Cloud(model) => match model {
189 CloudModel::Claude3Opus
190 | CloudModel::Claude3Sonnet
191 | CloudModel::Claude3Haiku
192 | CloudModel::Claude3_5Sonnet => {
193 preprocess_anthropic_request(self);
194 }
195 _ => {}
196 },
197 }
198 }
199}
200
/// A message (or streamed partial message) received from a language model;
/// both fields are optional because streamed deltas may carry either one.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
206
/// Token accounting returned by a completion provider.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
213
/// One streamed choice delta from a completion response
/// (OpenAI-style chunk shape — index, partial message, optional finish reason).
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
220
/// Per-message bookkeeping kept alongside a conversation message.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}
226
/// Lifecycle of a message: in flight, completed, or failed with a
/// user-visible error string.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}
233
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Registered as a GPUI global (see `init`) so it can be read and updated
// from anywhere through the app context.
impl Global for Assistant {}
242
243impl Assistant {
244 const NAMESPACE: &'static str = "assistant";
245
246 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
247 if self.enabled == enabled {
248 return;
249 }
250
251 self.enabled = enabled;
252
253 if !enabled {
254 CommandPaletteFilter::update_global(cx, |filter, _cx| {
255 filter.hide_namespace(Self::NAMESPACE);
256 });
257
258 return;
259 }
260
261 CommandPaletteFilter::update_global(cx, |filter, _cx| {
262 filter.show_namespace(Self::NAMESPACE);
263 });
264 }
265}
266
/// Initializes the assistant subsystem: registers globals, settings, slash
/// commands, panels, and the inline assistant, and keeps the enabled flag
/// in sync with `AssistantSettings`.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background; any error from
    // `SemanticIndex::new` propagates out of the detached task.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    RustdocStore::init_global(cx);

    // Start hidden, then apply the current setting; `set_enabled` only acts
    // on state *changes*, so the unconditional hide ensures the filter and
    // flag agree regardless of the initial setting value.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the enabled flag whenever settings change.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
310
/// Registers all built-in slash commands with the global registry.
/// NOTE(review): the boolean appears to mark commands that are featured or
/// enabled by default (rustdoc/fetch are opted out) — confirm against
/// `SlashCommandRegistry::register_command`.
fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
    slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
    slash_command_registry.register_command(term_command::TermSlashCommand, true);
    slash_command_registry.register_command(now_command::NowSlashCommand, true);
    slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
    slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
}
326
/// Formats a token count compactly for display:
/// - below 1,000: the plain number ("999")
/// - 1,000–9,999: one decimal place of thousands ("1.5k"), with the decimal
///   dropped when it rounds to .0 and the whole part bumped when it rounds
///   to a full thousand ("2k")
/// - 10,000 and above: whole thousands, rounded to nearest ("11k")
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        return count.to_string();
    }
    if count < 10_000 {
        let whole = count / 1000;
        // Tenths of a thousand, rounded half-up (+50 before truncating).
        let tenths = (count % 1000 + 50) / 100;
        return match tenths {
            0 => format!("{whole}k"),
            // Rounded all the way up to the next thousand.
            10 => format!("{}k", whole + 1),
            _ => format!("{whole}.{tenths}k"),
        };
    }
    // Round to the nearest whole thousand.
    format!("{}k", (count + 500) / 1000)
}
344
/// Test-only hook, run before tests start via `ctor`, that enables
/// `env_logger` output when `RUST_LOG` is set in the environment.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}