1package provider
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/llm/tools"
12 "github.com/charmbracelet/crush/internal/message"
13)
14
15type bedrockClient struct {
16 providerOptions providerClientOptions
17 childProvider ProviderClient
18}
19
20type BedrockClient ProviderClient
21
22func newBedrockClient(opts providerClientOptions) BedrockClient {
23 // Get AWS region from environment
24 region := opts.extraParams["region"]
25 if region == "" {
26 region = "us-east-1" // default region
27 }
28 if len(region) < 2 {
29 return &bedrockClient{
30 providerOptions: opts,
31 childProvider: nil, // Will cause an error when used
32 }
33 }
34
35 opts.model = func(modelType config.SelectedModelType) catwalk.Model {
36 model := config.Get().GetModelByType(modelType)
37
38 // Prefix the model name with region
39 regionPrefix := region[:2]
40 modelName := model.ID
41 model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
42 return *model
43 }
44
45 model := opts.model(opts.modelType)
46
47 // Determine which provider to use based on the model
48 if strings.Contains(string(model.ID), "anthropic") {
49 // Create Anthropic client with Bedrock configuration
50 anthropicOpts := opts
51 // TODO: later find a way to check if the AWS account has caching enabled
52 opts.disableCache = true // Disable cache for Bedrock
53 return &bedrockClient{
54 providerOptions: opts,
55 childProvider: newAnthropicClient(anthropicOpts, AnthropicClientTypeBedrock),
56 }
57 }
58
59 // Return client with nil childProvider if model is not supported
60 // This will cause an error when used
61 return &bedrockClient{
62 providerOptions: opts,
63 childProvider: nil,
64 }
65}
66
67func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
68 if b.childProvider == nil {
69 return nil, errors.New("unsupported model for bedrock provider")
70 }
71 return b.childProvider.send(ctx, messages, tools)
72}
73
74func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
75 eventChan := make(chan ProviderEvent)
76
77 if b.childProvider == nil {
78 go func() {
79 eventChan <- ProviderEvent{
80 Type: EventError,
81 Error: errors.New("unsupported model for bedrock provider"),
82 }
83 close(eventChan)
84 }()
85 return eventChan
86 }
87
88 return b.childProvider.stream(ctx, messages, tools)
89}
90
91func (b *bedrockClient) Model() catwalk.Model {
92 return b.providerOptions.model(b.providerOptions.modelType)
93}