1package provider
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "strings"
9
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
// bedrockClient is the Bedrock provider client. It does no request handling
// itself: send and stream forward to childProvider when one is set.
type bedrockClient struct {
	// providerOptions holds the options the client was constructed with, as
	// stored by newBedrockClient.
	providerOptions providerClientOptions
	// childProvider is the client that send and stream delegate to. It is nil
	// when newBedrockClient could not pick a provider, in which case send and
	// stream report an error instead of delegating.
	childProvider ProviderClient
}
18
// BedrockClient is the type returned by newBedrockClient. It is a defined type
// whose underlying type is ProviderClient.
type BedrockClient ProviderClient
20
21func newBedrockClient(opts providerClientOptions) BedrockClient {
22 // Apply bedrock specific options if they are added in the future
23
24 // Get AWS region from environment
25 region := os.Getenv("AWS_REGION")
26 if region == "" {
27 region = os.Getenv("AWS_DEFAULT_REGION")
28 }
29
30 if region == "" {
31 region = "us-east-1" // default region
32 }
33 if len(region) < 2 {
34 return &bedrockClient{
35 providerOptions: opts,
36 childProvider: nil, // Will cause an error when used
37 }
38 }
39
40 // Prefix the model name with region
41 regionPrefix := region[:2]
42 modelName := opts.model.APIModel
43 opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
44
45 // Determine which provider to use based on the model
46 if strings.Contains(string(opts.model.APIModel), "anthropic") {
47 // Create Anthropic client with Bedrock configuration
48 anthropicOpts := opts
49 // TODO: later find a way to check if the AWS account has caching enabled
50 opts.disableCache = true // Disable cache for Bedrock
51 return &bedrockClient{
52 providerOptions: opts,
53 childProvider: newAnthropicClient(anthropicOpts, true),
54 }
55 }
56
57 // Return client with nil childProvider if model is not supported
58 // This will cause an error when used
59 return &bedrockClient{
60 providerOptions: opts,
61 childProvider: nil,
62 }
63}
64
65func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
66 if b.childProvider == nil {
67 return nil, errors.New("unsupported model for bedrock provider")
68 }
69 return b.childProvider.send(ctx, messages, tools)
70}
71
72func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
73 eventChan := make(chan ProviderEvent)
74
75 if b.childProvider == nil {
76 go func() {
77 eventChan <- ProviderEvent{
78 Type: EventError,
79 Error: errors.New("unsupported model for bedrock provider"),
80 }
81 close(eventChan)
82 }()
83 return eventChan
84 }
85
86 return b.childProvider.stream(ctx, messages, tools)
87}