diff --git a/bedrock/README.md b/bedrock/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa275b175ba322cad6e9212e35615d8ae1415675 --- /dev/null +++ b/bedrock/README.md @@ -0,0 +1,3 @@ +# Bedrock + +* Create an API key [on this page](https://eu-north-1.console.aws.amazon.com/bedrock/home#/api-keys/long-term/create). diff --git a/bedrock/bedrock.go b/bedrock/bedrock.go new file mode 100644 index 0000000000000000000000000000000000000000..e44878994e43c1f04257b3603d00573d999979bf --- /dev/null +++ b/bedrock/bedrock.go @@ -0,0 +1,125 @@ +package bedrock + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/charmbracelet/fantasy/ai" +) + +const ( + Name = "bedrock" + // DefaultURL = "https://bedrock-runtime.amazonaws.com" +) + +type options struct { + name string + httpClient bedrockruntime.HTTPClient + + // region string + // accessKey string + // secretKey string + // sessionToken string + // endpoint string +} + +type provider struct { + options options + client *bedrockruntime.Client +} + +type Option = func(*options) + +func New(ctx context.Context, opts ...Option) (ai.Provider, error) { + providerOptions := options{ + name: Name, + } + for _, o := range opts { + o(&providerOptions) + } + + cfg, err := config.LoadDefaultConfig(ctx) //, config.WithRegion(providerOptions.region)) + if err != nil { + return nil, fmt.Errorf("fantasy: unable to load default aws config: %w", err) + } + + // if providerOptions.accessKey != "" && providerOptions.secretKey != "" { + // cfg.Credentials = aws.CredentialsProviderFunc( + // func(ctx context.Context) (aws.Credentials, error) { + // return aws.Credentials{ + // AccessKeyID: providerOptions.accessKey, + // SecretAccessKey: providerOptions.secretKey, + // SessionToken: providerOptions.sessionToken, + // }, nil + // }, + // ) + // } + + client := bedrockruntime.NewFromConfig( + cfg, + func(o *bedrockruntime.Options) { + if providerOptions.httpClient != nil { + o.HTTPClient = providerOptions.httpClient + } + }, + ) + + return &provider{ + options: providerOptions, + client: client, + }, nil +} + +func WithName(name string) Option { + return func(o *options) { + o.name = name + } +} + +func WithHTTPClient(httpClient bedrockruntime.HTTPClient) Option { + return func(o *options) { + o.httpClient = httpClient + } +} + +// func WithRegion(region string) Option { +// return func(o *options) { +// o.region = region +// } +// } + +// func WithCredentials(accessKey, secretKey, sessionToken string) Option { +// return func(o *options) { +// o.accessKey = accessKey +// o.secretKey = secretKey +// o.sessionToken = sessionToken +// } +// } + +// func WithEndpoint(endpoint string) Option { +// return func(o *options) { +// o.endpoint = endpoint +// } +// } + +func (b *provider) Name() string { + return Name +} + +func (b *provider) LanguageModel(modelID string) (ai.LanguageModel, error) { + return languageModel{ + modelID: modelID, + provider: b.options.name, + client: b.client, + }, nil +} + +func (b *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) { + var options ProviderOptions + if err := ai.ParseOptions(data, &options); err != nil { + return nil, err + } + return &options, nil +} diff --git a/bedrock/claude.go b/bedrock/claude.go new file mode 100644 index 0000000000000000000000000000000000000000..49ff4b104246c64ff775b1dc5f56a8835e8773c5 --- /dev/null +++ b/bedrock/claude.go @@ -0,0 +1,54 @@ +package bedrock + +import ( + "fmt" + + "github.com/charmbracelet/fantasy/ai" +) + +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + +type claudeInput struct { + AnthropicVersion string `json:"anthropic_version"` + MaxTokens *int64 `json:"max_tokens"` + Messages []claudeMessage `json:"messages"` +} + +type claudeMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +func toClaudeInput(call ai.Call) (claudeInput, error) { + var err error + i := claudeInput{ + MaxTokens: call.MaxOutputTokens, + } + + i.Messages, err = toClaudePrompt(call.Prompt) + if err != nil { + return i, err + } + return i, nil +} + +func toClaudePrompt(prompt ai.Prompt) (messages []claudeMessage, err error) { + for _, m := range prompt { + message := claudeMessage{ + Role: string(m.Role), + } + + for _, part := range m.Content { + switch content := part.(type) { + case ai.TextPart: + message.Content = content.Text + default: + return nil, fmt.Errorf("fantasy: ") + } + } + + messages = append(messages, message) + } + return messages, err +} diff --git a/bedrock/language_model.go b/bedrock/language_model.go new file mode 100644 index 0000000000000000000000000000000000000000..f5576a43e7f8b498f52e00184e2a34042b50690f --- /dev/null +++ b/bedrock/language_model.go @@ -0,0 +1,85 @@ +package bedrock + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/charmbracelet/fantasy/ai" +) + +type languageModel struct { + provider string + modelID string + client *bedrockruntime.Client +} + +func (b languageModel) Model() string { + return b.modelID +} + +func (b languageModel) Provider() string { + return b.provider +} + +func (b languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) { + params, err := b.prepareParams(call) + if err != nil { + return nil, err + } + + output, err := b.client.InvokeModel(ctx, params) + if err != nil { + return nil, err + } + + panic(fmt.Sprintf("bedrock output: %+v", output)) + + // return &ai.Response{ + // Content: content, + // // Usage: ai.Usage{ + // // InputTokens: response.Usage.InputTokens, + // // OutputTokens: response.Usage.OutputTokens, + // // TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + // // CacheCreationTokens: response.Usage.CacheCreationInputTokens, + // // CacheReadTokens: response.Usage.CacheReadInputTokens, + // // }, + // // FinishReason: mapFinishReason(string(response.StopReason)), + // ProviderMetadata: ai.ProviderMetadata{}, + // // Warnings: warnings, + // }, nil +} + +func (b languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) { + return nil, errors.New("bedrock provider not fully implemented") +} + +func (b languageModel) prepareParams(call ai.Call) (*bedrockruntime.InvokeModelInput, error) { + input := bedrockruntime.InvokeModelInput{ + // ModelId: ptr(fmt.Sprintf("us-east-1.%s", b.modelID)), + ModelId: ptr(b.modelID), + // ModelId: ptr("us-east-1.anthropic.claude-sonnet-4-5-v2:0"), + ContentType: ptr("application/json"), + // Body: body, + } + + // call.Prompt + + switch { + case containsAny(b.modelID, "anthropic", "claude", "sonnet"): + i, err := toClaudeInput(call) + if err != nil { + return nil, err + } + input.Body, err = json.Marshal(i) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("fantasy: bedrock provider does not support model: %s", b.modelID) + } + + return &input, nil +} diff --git a/bedrock/misc.go b/bedrock/misc.go new file mode 100644 index 0000000000000000000000000000000000000000..5557caee9fec617122bbdb6b62dcd9ad642fe16b --- /dev/null +++ b/bedrock/misc.go @@ -0,0 +1,18 @@ +package bedrock + +import ( + "strings" +) + +func ptr[T any](v T) *T { + return &v +} + +func containsAny(str string, options ...string) bool { + for _, option := range options { + if strings.Contains(str, option) { + return true + } + } + return false +} diff --git a/bedrock/provider_options.go b/bedrock/provider_options.go new file mode 100644 index 0000000000000000000000000000000000000000..315bf17aca76a0be450bb80f8df9d2c4b6ccd5a8 --- /dev/null +++ b/bedrock/provider_options.go @@ -0,0 +1,15 @@ +package bedrock + +import "github.com/charmbracelet/fantasy/ai" + +type ProviderOptions struct { + // Add Bedrock-specific options here +} + +func (o *ProviderOptions) Options() {} + +func NewProviderOptions(opts *ProviderOptions) ai.ProviderOptions { + return ai.ProviderOptions{ + Name: opts, + } +} diff --git a/go.mod b/go.mod index f8b013a0b6368491ea09a45fe9a85c5eed723896..a573a1cead1fc474311230b67f0d8a46ac52e08d 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.24.5 require ( cloud.google.com/go/auth v0.9.3 github.com/anthropics/anthropic-sdk-go v1.10.0 + github.com/aws/aws-sdk-go-v2/config v1.27.27 + github.com/aws/aws-sdk-go-v2/service/bedrock v1.47.2 github.com/charmbracelet/x/exp/slice v0.0.0-20250904123553-b4e2667e5ad5 github.com/charmbracelet/x/json v0.2.0 github.com/go-viper/mapstructure/v2 v2.4.0 @@ -24,6 +26,20 @@ require ( cloud.google.com/go/compute/metadata v0.5.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.39.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.41.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect + github.com/aws/smithy-go v1.23.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect diff --git a/go.sum b/go.sum index 0395e7de5c3080227f03a6b3003fe0a112b14225..65f8cd7ae11d5365b248bd127daa6509b0175f1c 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,38 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkY github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/aws/aws-sdk-go-v2 v1.39.2 h1:EJLg8IdbzgeD7xgvZ+I8M1e0fL0ptn/M47lianzth0I= +github.com/aws/aws-sdk-go-v2 v1.39.2/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 h1:i8p8P4diljCr60PpJp6qZXNlgX4m2yQFpYk+9ZT+J4E= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1/go.mod h1:ddqbooRZYNoJ2dsTwOty16rM+/Aqmk/GOXrK8cg7V00= +github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= +github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.9 h1:se2vOWGD3dWQUtfn4wEjRQJb1HK1XsNIt825gskZ970= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.9/go.mod h1:hijCGH2VfbZQxqCDN7bwz/4dzxV+hkyhjawAtdPWKZA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 h1:6RBnKZLkJM4hQ+kN6E7yWFveOTg8NLPHAkqrs4ZPlTU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9/go.mod h1:V9rQKRmK7AWuEsOMnHzKj8WyrIir1yUJbZxDuZLFvXI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/bedrock v1.47.2 h1:5Bq66lHNfiPHM9WBwzQfhqqctRTxXF3+Un1bm9ZyThE= +github.com/aws/aws-sdk-go-v2/service/bedrock v1.47.2/go.mod h1:3sUHFSHdoib4v7JdqEGgxD2sIdTDikr4IpjBOgUAa0g= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.41.0 h1:xdYdX+JpIFByMG8JQe9iWM9CqepyjhenukxTVQnuCbM= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.41.0/go.mod h1:c1Ik+59wgLIJFhsSY8cAnw6QooiogpTZKP0rtkVcpCQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= +github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= +github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251010172108-7b952cdeeb9d h1:qP7F7r7aVY7AReYHHgkQ79weuUEZK7zXtDtSEydYb0w= github.com/charmbracelet/anthropic-sdk-go v0.0.0-20251010172108-7b952cdeeb9d/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= diff --git a/providertests/bedrock_test.go b/providertests/bedrock_test.go new file mode 100644 index 0000000000000000000000000000000000000000..650152c3b46f34f2cdbd2228fecd60e56921d290 --- /dev/null +++ b/providertests/bedrock_test.go @@ -0,0 +1,47 @@ +package providertests + +import ( + "net/http" + "testing" + + "github.com/charmbracelet/fantasy/ai" + "github.com/charmbracelet/fantasy/bedrock" + "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" +) + +// const defaultBaseURL = "https://fantasy-playground-resource.services.ai.azure.com/" + +func TestBedrockCommon(t *testing.T) { + testCommon(t, []builderPair{ + {"bedrock-anthropic-claude-v2", builderBedrockClaudeV2(t), nil}, + }) +} + +// func TestBedrockThinking(t *testing.T) { +// opts := ai.ProviderOptions{ +// bedrock.Name: &bedrock.ProviderOptions{ +// ReasoningEffort: openai.ReasoningEffortOption(openai.ReasoningEffortLow), +// }, +// } +// testThinking(t, []builderPair{ +// {"bedrock-anthropic-claude-v2", builderBedrockClaudeV2(t), opts}, +// }, testBedrockThinking) +// } + +// func testBedrockThinking(t *testing.T, result *ai.AgentResult) { +// require.Greater(t, result.Response.Usage.ReasoningTokens, int64(0), "expected reasoning tokens, got none") +// } + +func builderBedrockClaudeV2(t *testing.T) func(r *recorder.Recorder) (ai.LanguageModel, error) { + return func(r *recorder.Recorder) (ai.LanguageModel, error) { + provider, err := bedrock.New( + t.Context(), + bedrock.WithHTTPClient(&http.Client{Transport: r}), + ) + if err != nil { + return nil, err + } + // return provider.LanguageModel("anthropic.claude-sonnet-4-5-20250929-v1:0") + return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0") + } +}