pub mod assistant_panel;
pub mod assistant_settings;
mod codegen;
mod completion_provider;
mod model_selector;
mod prompt_library;
mod prompts;
mod saved_conversation;
mod search;
mod slash_command;
mod streaming_diff;

pub use assistant_panel::AssistantPanel;

use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
pub(crate) use model_selector::*;
use prompt_library::PromptStore;
pub(crate) use saved_conversation::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use slash_command::{
    active_command, file_command, project_command, prompt_command, rustdoc_command, search_command,
    tabs_command,
};
use std::{
    fmt::{self, Display},
    sync::Arc,
};
use util::paths::EMBEDDINGS_DIR;

actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
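
// The `assistant` namespace above matches `Assistant::NAMESPACE` below; the command
// palette filter uses it to show or hide these actions when the assistant is enabled
// or disabled in settings.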

#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
    pub fn cycle(&mut self) {
        *self = match self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        }
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::User => write!(f, "user"),
            Role::Assistant => write!(f, "assistant"),
            Role::System => write!(f, "system"),
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    Cloud(CloudModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
}

impl Default for LanguageModel {
    fn default() -> Self {
        LanguageModel::Cloud(CloudModel::default())
    }
}

impl LanguageModel {
    pub fn telemetry_id(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
        }
    }

    pub fn display_name(&self) -> String {
        match self {
            LanguageModel::OpenAi(model) => model.display_name().into(),
            LanguageModel::Anthropic(model) => model.display_name().into(),
            LanguageModel::Cloud(model) => model.display_name().into(),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            LanguageModel::OpenAi(model) => model.max_token_count(),
            LanguageModel::Anthropic(model) => model.max_token_count(),
            LanguageModel::Cloud(model) => model.max_token_count(),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            LanguageModel::OpenAi(model) => model.id(),
            LanguageModel::Anthropic(model) => model.id(),
            LanguageModel::Cloud(model) => model.id(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}

impl LanguageModelRequestMessage {
    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
        proto::LanguageModelRequestMessage {
            role: match self.role {
                Role::User => proto::LanguageModelRole::LanguageModelUser,
                Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
                Role::System => proto::LanguageModelRole::LanguageModelSystem,
            } as i32,
            content: self.content.clone(),
            tool_calls: Vec::new(),
            tool_call_id: None,
        }
    }
}

#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}

impl LanguageModelRequest {
    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
        proto::CompleteWithLanguageModel {
            model: self.model.id().to_string(),
            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
            stop: self.stop.clone(),
            temperature: self.temperature,
            tool_choice: None,
            tools: Vec::new(),
        }
    }

    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
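    /// For the hosted Claude models this delegates to `preprocess_anthropic_request`
    /// (presumably brought in via the `completion_provider` glob import above), which
    /// is assumed to normalize the message sequence to what Anthropic's API expects.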
    pub fn preprocess(&mut self) {
        match &self.model {
            LanguageModel::OpenAi(_) => {}
            LanguageModel::Anthropic(_) => {}
            LanguageModel::Cloud(model) => match model {
                CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
                    preprocess_anthropic_request(self);
                }
                _ => {}
            },
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}

/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

impl Global for Assistant {}

impl Assistant {
    const NAMESPACE: &'static str = "assistant";

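    /// Enables or disables the assistant, showing or hiding its actions in the
    /// command palette by toggling the `assistant` namespace filter.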
    fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
        if self.enabled == enabled {
            return;
        }

        self.enabled = enabled;

        if !enabled {
            CommandPaletteFilter::update_global(cx, |filter, _cx| {
                filter.hide_namespace(Self::NAMESPACE);
            });

            return;
        }

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            filter.show_namespace(Self::NAMESPACE);
        });
    }
}

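/// Initializes the assistant: registers its settings and slash commands, starts the
/// semantic index, sets up the completion provider, and keeps the enabled state in
/// sync with the settings store.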
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

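    // Build the semantic index asynchronously and store it as a global once it is
    // ready; any error from this task is detached rather than surfaced.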
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    prompt_library::init(cx);
    completion_provider::init(client, cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);

    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}

fn register_slash_commands(cx: &mut AppContext) {
    let slash_command_registry = SlashCommandRegistry::global(cx);
    slash_command_registry.register_command(file_command::FileSlashCommand, true);
    slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
    slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
    slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
    slash_command_registry.register_command(search_command::SearchSlashCommand, true);
    slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);

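    // The prompt slash command needs the global `PromptStore`, which resolves
    // asynchronously, so it is registered from a background task once the store
    // is ready.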
    let store = PromptStore::global(cx);
    cx.background_executor()
        .spawn(async move {
            let store = store.await?;
            slash_command_registry
                .register_command(prompt_command::PromptSlashCommand::new(store), true);
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
}

#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}