bedrock.go

 1package provider
 2
 3import (
 4	"context"
 5	"errors"
 6	"fmt"
 7	"os"
 8	"strings"
 9
10	"github.com/charmbracelet/crush/internal/llm/tools"
11	"github.com/charmbracelet/crush/internal/message"
12)
13
// bedrockClient is the provider client for AWS Bedrock. It performs no request
// handling of its own: send and stream hand every call to childProvider.
type bedrockClient struct {
	// providerOptions holds the options newBedrockClient stored for this client.
	providerOptions providerClientOptions
	// childProvider is the client that send and stream delegate to. It is nil
	// when newBedrockClient could not select one, in which case send and stream
	// report an error instead of delegating.
	childProvider   ProviderClient
}
18
// BedrockClient is the type returned by newBedrockClient. It is declared with
// ProviderClient as its underlying type.
type BedrockClient ProviderClient
20
21func newBedrockClient(opts providerClientOptions) BedrockClient {
22	// Apply bedrock specific options if they are added in the future
23
24	// Get AWS region from environment
25	region := os.Getenv("AWS_REGION")
26	if region == "" {
27		region = os.Getenv("AWS_DEFAULT_REGION")
28	}
29
30	if region == "" {
31		region = "us-east-1" // default region
32	}
33	if len(region) < 2 {
34		return &bedrockClient{
35			providerOptions: opts,
36			childProvider:   nil, // Will cause an error when used
37		}
38	}
39
40	// Prefix the model name with region
41	regionPrefix := region[:2]
42	modelName := opts.model.APIModel
43	opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
44
45	// Determine which provider to use based on the model
46	if strings.Contains(string(opts.model.APIModel), "anthropic") {
47		// Create Anthropic client with Bedrock configuration
48		anthropicOpts := opts
49		// TODO: later find a way to check if the AWS account has caching enabled
50		opts.disableCache = true // Disable cache for Bedrock
51		return &bedrockClient{
52			providerOptions: opts,
53			childProvider:   newAnthropicClient(anthropicOpts, true),
54		}
55	}
56
57	// Return client with nil childProvider if model is not supported
58	// This will cause an error when used
59	return &bedrockClient{
60		providerOptions: opts,
61		childProvider:   nil,
62	}
63}
64
65func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
66	if b.childProvider == nil {
67		return nil, errors.New("unsupported model for bedrock provider")
68	}
69	return b.childProvider.send(ctx, messages, tools)
70}
71
72func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
73	eventChan := make(chan ProviderEvent)
74
75	if b.childProvider == nil {
76		go func() {
77			eventChan <- ProviderEvent{
78				Type:  EventError,
79				Error: errors.New("unsupported model for bedrock provider"),
80			}
81			close(eventChan)
82		}()
83		return eventChan
84	}
85
86	return b.childProvider.stream(ctx, messages, tools)
87}