1package provider
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
14type bedrockClient struct {
15 providerOptions providerClientOptions
16 childProvider ProviderClient
17}
18
19type BedrockClient ProviderClient
20
21func newBedrockClient(opts providerClientOptions) BedrockClient {
22 // Get AWS region from environment
23 region := opts.extraParams["region"]
24 if region == "" {
25 region = "us-east-1" // default region
26 }
27 if len(region) < 2 {
28 return &bedrockClient{
29 providerOptions: opts,
30 childProvider: nil, // Will cause an error when used
31 }
32 }
33
34 opts.model = func(modelType config.ModelType) config.Model {
35 model := config.GetModel(modelType)
36
37 // Prefix the model name with region
38 regionPrefix := region[:2]
39 modelName := model.ID
40 model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
41 return model
42 }
43
44 model := opts.model(opts.modelType)
45
46 // Determine which provider to use based on the model
47 if strings.Contains(string(model.ID), "anthropic") {
48 // Create Anthropic client with Bedrock configuration
49 anthropicOpts := opts
50 // TODO: later find a way to check if the AWS account has caching enabled
51 opts.disableCache = true // Disable cache for Bedrock
52 return &bedrockClient{
53 providerOptions: opts,
54 childProvider: newAnthropicClient(anthropicOpts, true),
55 }
56 }
57
58 // Return client with nil childProvider if model is not supported
59 // This will cause an error when used
60 return &bedrockClient{
61 providerOptions: opts,
62 childProvider: nil,
63 }
64}
65
66func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
67 if b.childProvider == nil {
68 return nil, errors.New("unsupported model for bedrock provider")
69 }
70 return b.childProvider.send(ctx, messages, tools)
71}
72
73func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
74 eventChan := make(chan ProviderEvent)
75
76 if b.childProvider == nil {
77 go func() {
78 eventChan <- ProviderEvent{
79 Type: EventError,
80 Error: errors.New("unsupported model for bedrock provider"),
81 }
82 close(eventChan)
83 }()
84 return eventChan
85 }
86
87 return b.childProvider.stream(ctx, messages, tools)
88}
89
90func (b *bedrockClient) Model() config.Model {
91 return b.providerOptions.model(b.providerOptions.modelType)
92}