bedrock.go

  1package provider
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"os"
  8	"strings"
  9
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12)
 13
 14type bedrockOptions struct {
 15	// Bedrock specific options can be added here
 16}
 17
 18type BedrockOption func(*bedrockOptions)
 19
 20type bedrockClient struct {
 21	providerOptions providerClientOptions
 22	options         bedrockOptions
 23	childProvider   ProviderClient
 24}
 25
 26type BedrockClient ProviderClient
 27
 28func newBedrockClient(opts providerClientOptions) BedrockClient {
 29	bedrockOpts := bedrockOptions{}
 30	// Apply bedrock specific options if they are added in the future
 31
 32	// Get AWS region from environment
 33	region := os.Getenv("AWS_REGION")
 34	if region == "" {
 35		region = os.Getenv("AWS_DEFAULT_REGION")
 36	}
 37
 38	if region == "" {
 39		region = "us-east-1" // default region
 40	}
 41	if len(region) < 2 {
 42		return &bedrockClient{
 43			providerOptions: opts,
 44			options:         bedrockOpts,
 45			childProvider:   nil, // Will cause an error when used
 46		}
 47	}
 48
 49	// Prefix the model name with region
 50	regionPrefix := region[:2]
 51	modelName := opts.model.APIModel
 52	opts.model.APIModel = fmt.Sprintf("%s.%s", regionPrefix, modelName)
 53
 54	// Determine which provider to use based on the model
 55	if strings.Contains(string(opts.model.APIModel), "anthropic") {
 56		// Create Anthropic client with Bedrock configuration
 57		anthropicOpts := opts
 58		anthropicOpts.anthropicOptions = append(anthropicOpts.anthropicOptions,
 59			WithAnthropicBedrock(true),
 60			WithAnthropicDisableCache(),
 61		)
 62		return &bedrockClient{
 63			providerOptions: opts,
 64			options:         bedrockOpts,
 65			childProvider:   newAnthropicClient(anthropicOpts),
 66		}
 67	}
 68
 69	// Return client with nil childProvider if model is not supported
 70	// This will cause an error when used
 71	return &bedrockClient{
 72		providerOptions: opts,
 73		options:         bedrockOpts,
 74		childProvider:   nil,
 75	}
 76}
 77
 78func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 79	if b.childProvider == nil {
 80		return nil, errors.New("unsupported model for bedrock provider")
 81	}
 82	return b.childProvider.send(ctx, messages, tools)
 83}
 84
 85func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
 86	eventChan := make(chan ProviderEvent)
 87
 88	if b.childProvider == nil {
 89		go func() {
 90			eventChan <- ProviderEvent{
 91				Type:  EventError,
 92				Error: errors.New("unsupported model for bedrock provider"),
 93			}
 94			close(eventChan)
 95		}()
 96		return eventChan
 97	}
 98
 99	return b.childProvider.stream(ctx, messages, tools)
100}