1mod model;
2pub mod provider;
3mod registry;
4mod request;
5mod role;
6pub mod settings;
7
8use std::sync::Arc;
9
10use anyhow::Result;
11use client::Client;
12use futures::{future::BoxFuture, stream::BoxStream};
13use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
14
15pub use model::*;
16use project::Fs;
17pub use registry::*;
18pub use request::*;
19pub use role::*;
20use schemars::JsonSchema;
21use serde::de::DeserializeOwned;
22
23pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
24 settings::init(fs, cx);
25 registry::init(client, cx);
26}
27
28pub trait LanguageModel: Send + Sync {
29 fn id(&self) -> LanguageModelId;
30 fn name(&self) -> LanguageModelName;
31 fn provider_id(&self) -> LanguageModelProviderId;
32 fn provider_name(&self) -> LanguageModelProviderName;
33 fn telemetry_id(&self) -> String;
34
35 fn max_token_count(&self) -> usize;
36
37 fn count_tokens(
38 &self,
39 request: LanguageModelRequest,
40 cx: &AppContext,
41 ) -> BoxFuture<'static, Result<usize>>;
42
43 fn stream_completion(
44 &self,
45 request: LanguageModelRequest,
46 cx: &AsyncAppContext,
47 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
48
49 fn use_tool(
50 &self,
51 request: LanguageModelRequest,
52 name: String,
53 description: String,
54 schema: serde_json::Value,
55 cx: &AsyncAppContext,
56 ) -> BoxFuture<'static, Result<serde_json::Value>>;
57}
58
59pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
60 fn name() -> String;
61 fn description() -> String;
62}
63
64pub trait LanguageModelProvider: 'static {
65 fn id(&self) -> LanguageModelProviderId;
66 fn name(&self) -> LanguageModelProviderName;
67 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
68 fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
69 fn is_authenticated(&self, cx: &AppContext) -> bool;
70 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
71 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
72 fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
73}
74
75pub trait LanguageModelProviderState: 'static {
76 fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
77}
78
79#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
80pub struct LanguageModelId(pub SharedString);
81
82#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
83pub struct LanguageModelName(pub SharedString);
84
85#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
86pub struct LanguageModelProviderId(pub SharedString);
87
88#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
89pub struct LanguageModelProviderName(pub SharedString);
90
91impl From<String> for LanguageModelId {
92 fn from(value: String) -> Self {
93 Self(SharedString::from(value))
94 }
95}
96
97impl From<String> for LanguageModelName {
98 fn from(value: String) -> Self {
99 Self(SharedString::from(value))
100 }
101}
102
103impl From<String> for LanguageModelProviderId {
104 fn from(value: String) -> Self {
105 Self(SharedString::from(value))
106 }
107}
108
109impl From<String> for LanguageModelProviderName {
110 fn from(value: String) -> Self {
111 Self(SharedString::from(value))
112 }
113}