1mod ambient_context;
2pub mod assistant_panel;
3pub mod assistant_settings;
4mod codegen;
5mod completion_provider;
6mod prompt_library;
7mod prompts;
8mod saved_conversation;
9mod streaming_diff;
10
11pub use assistant_panel::AssistantPanel;
12use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
13use client::{proto, Client};
14use command_palette_hooks::CommandPaletteFilter;
15pub(crate) use completion_provider::*;
16use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
17pub(crate) use saved_conversation::*;
18use serde::{Deserialize, Serialize};
19use settings::{Settings, SettingsStore};
20use std::{
21 fmt::{self, Display},
22 sync::Arc,
23};
24
25actions!(
26 assistant,
27 [
28 Assist,
29 Split,
30 CycleMessageRole,
31 QuoteSelection,
32 ToggleFocus,
33 ResetKey,
34 InlineAssist,
35 InsertActivePrompt,
36 ToggleIncludeConversation,
37 ToggleHistory,
38 ]
39);
40
41#[derive(
42 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
43)]
44struct MessageId(usize);
45
46#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
47#[serde(rename_all = "lowercase")]
48pub enum Role {
49 User,
50 Assistant,
51 System,
52}
53
54impl Role {
55 pub fn cycle(&mut self) {
56 *self = match self {
57 Role::User => Role::Assistant,
58 Role::Assistant => Role::System,
59 Role::System => Role::User,
60 }
61 }
62}
63
64impl Display for Role {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 Role::User => write!(f, "user"),
68 Role::Assistant => write!(f, "assistant"),
69 Role::System => write!(f, "system"),
70 }
71 }
72}
73
74#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
75pub enum LanguageModel {
76 ZedDotDev(ZedDotDevModel),
77 OpenAi(OpenAiModel),
78 Anthropic(AnthropicModel),
79}
80
81impl Default for LanguageModel {
82 fn default() -> Self {
83 LanguageModel::ZedDotDev(ZedDotDevModel::default())
84 }
85}
86
87impl LanguageModel {
88 pub fn telemetry_id(&self) -> String {
89 match self {
90 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
91 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
92 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
93 }
94 }
95
96 pub fn display_name(&self) -> String {
97 match self {
98 LanguageModel::OpenAi(model) => model.display_name().into(),
99 LanguageModel::Anthropic(model) => model.display_name().into(),
100 LanguageModel::ZedDotDev(model) => model.display_name().into(),
101 }
102 }
103
104 pub fn max_token_count(&self) -> usize {
105 match self {
106 LanguageModel::OpenAi(model) => model.max_token_count(),
107 LanguageModel::Anthropic(model) => model.max_token_count(),
108 LanguageModel::ZedDotDev(model) => model.max_token_count(),
109 }
110 }
111
112 pub fn id(&self) -> &str {
113 match self {
114 LanguageModel::OpenAi(model) => model.id(),
115 LanguageModel::Anthropic(model) => model.id(),
116 LanguageModel::ZedDotDev(model) => model.id(),
117 }
118 }
119}
120
121#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
122pub struct LanguageModelRequestMessage {
123 pub role: Role,
124 pub content: String,
125}
126
127impl LanguageModelRequestMessage {
128 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
129 proto::LanguageModelRequestMessage {
130 role: match self.role {
131 Role::User => proto::LanguageModelRole::LanguageModelUser,
132 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
133 Role::System => proto::LanguageModelRole::LanguageModelSystem,
134 } as i32,
135 content: self.content.clone(),
136 tool_calls: Vec::new(),
137 tool_call_id: None,
138 }
139 }
140}
141
142#[derive(Debug, Default, Serialize)]
143pub struct LanguageModelRequest {
144 pub model: LanguageModel,
145 pub messages: Vec<LanguageModelRequestMessage>,
146 pub stop: Vec<String>,
147 pub temperature: f32,
148}
149
150impl LanguageModelRequest {
151 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
152 proto::CompleteWithLanguageModel {
153 model: self.model.id().to_string(),
154 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
155 stop: self.stop.clone(),
156 temperature: self.temperature,
157 tool_choice: None,
158 tools: Vec::new(),
159 }
160 }
161}
162
163#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
164pub struct LanguageModelResponseMessage {
165 pub role: Option<Role>,
166 pub content: Option<String>,
167}
168
169#[derive(Deserialize, Debug)]
170pub struct LanguageModelUsage {
171 pub prompt_tokens: u32,
172 pub completion_tokens: u32,
173 pub total_tokens: u32,
174}
175
176#[derive(Deserialize, Debug)]
177pub struct LanguageModelChoiceDelta {
178 pub index: u32,
179 pub delta: LanguageModelResponseMessage,
180 pub finish_reason: Option<String>,
181}
182
183#[derive(Clone, Debug, Serialize, Deserialize)]
184struct MessageMetadata {
185 role: Role,
186 status: MessageStatus,
187}
188
189#[derive(Clone, Debug, Serialize, Deserialize)]
190enum MessageStatus {
191 Pending,
192 Done,
193 Error(SharedString),
194}
195
196/// The state pertaining to the Assistant.
197#[derive(Default)]
198struct Assistant {
199 /// Whether the Assistant is enabled.
200 enabled: bool,
201}
202
203impl Global for Assistant {}
204
205impl Assistant {
206 const NAMESPACE: &'static str = "assistant";
207
208 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
209 if self.enabled == enabled {
210 return;
211 }
212
213 self.enabled = enabled;
214
215 if !enabled {
216 CommandPaletteFilter::update_global(cx, |filter, _cx| {
217 filter.hide_namespace(Self::NAMESPACE);
218 });
219
220 return;
221 }
222
223 CommandPaletteFilter::update_global(cx, |filter, _cx| {
224 filter.show_namespace(Self::NAMESPACE);
225 });
226 }
227}
228
229pub fn init(client: Arc<Client>, cx: &mut AppContext) {
230 cx.set_global(Assistant::default());
231 AssistantSettings::register(cx);
232 completion_provider::init(client, cx);
233 assistant_panel::init(cx);
234
235 CommandPaletteFilter::update_global(cx, |filter, _cx| {
236 filter.hide_namespace(Assistant::NAMESPACE);
237 });
238 Assistant::update_global(cx, |assistant, cx| {
239 let settings = AssistantSettings::get_global(cx);
240
241 assistant.set_enabled(settings.enabled, cx);
242 });
243 cx.observe_global::<SettingsStore>(|cx| {
244 Assistant::update_global(cx, |assistant, cx| {
245 let settings = AssistantSettings::get_global(cx);
246
247 assistant.set_enabled(settings.enabled, cx);
248 });
249 })
250 .detach();
251}
252
/// Test-only constructor hook: initializes `env_logger` when the `RUST_LOG`
/// environment variable can be read (set and valid Unicode).
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}