Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Commit

Permalink
gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
clement2026 committed Dec 22, 2023
1 parent f9f37f4 commit 7851934
Show file tree
Hide file tree
Showing 14 changed files with 342 additions and 41 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/caddyserver/certmagic v0.19.2
github.com/dustin/go-humanize v1.0.1
github.com/go-playground/validator/v10 v10.16.0
github.com/google/generative-ai-go v0.5.0
github.com/google/uuid v1.4.0
github.com/haguro/elevenlabs-go v0.2.2
github.com/labstack/echo/v4 v4.11.3
Expand All @@ -31,6 +32,7 @@ require (

require (
cloud.google.com/go v0.111.0 // indirect
cloud.google.com/go/ai v0.3.0 // indirect
cloud.google.com/go/compute v1.23.3 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.5 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.111.0 h1:YHLKNupSD1KqjDbQ3+LVdQ81h/UJbJyZG203cEfnQgM=
cloud.google.com/go v0.111.0/go.mod h1:0mibmpKP1TyOOFYQY5izo0LnT+ecvOQ0Sg3OdmMiNRU=
cloud.google.com/go/ai v0.3.0 h1:M617N0brv+XFch2KToZUhv6ggzgFZMUnmDkNQjW2pYg=
cloud.google.com/go/ai v0.3.0/go.mod h1:dTuQIBA8Kljuas5z1WNot1QZOl476A9TsFqEi6pzJlI=
cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk=
cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
Expand Down Expand Up @@ -68,6 +70,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/generative-ai-go v0.5.0 h1:PfzPuSGdsmcSyPG7RIoijcKWZ7/x2kvgyNryvmXMUmA=
github.com/google/generative-ai-go v0.5.0/go.mod h1:8fXQk4w+eyTzFokGGJrBFL0/xwXqm3QNhTqOWyX11zs=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
Expand Down
5 changes: 2 additions & 3 deletions internal/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (c *ChatHandler) Start(ms []client.Message, ar *AudioReader) {
if c.o.Completion {
text, err := c.completion(ctx, ms, client.RoleAssistant)
if err != nil {
c.logger.Sugar().Error("got empty text from completion", err)
c.logger.Sugar().Error("got empty text from completion, ", err)
return
}

Expand Down Expand Up @@ -204,10 +204,9 @@ func (c *ChatHandler) completion(ctx context.Context, latestMs []client.Message,
go func() { c.sse.PublishData(c.streamId, EventMessageThinking, meta) }()

stream := llm.CompletionStream(ctx, latestMs, *c.o.LLMOption)
defer stream.Close()

text := ""
for {

data, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type TextToSpeechConfig struct {

// LlmConfig names the credential keys for each supported LLM provider.
// The values are looked up in the creds map (see NewTalker) — a provider
// is enabled only when its key is present there.
type LlmConfig struct {
ChatGPT string `mapstructure:"chat-gpt"` // creds key for the OpenAI ChatGPT provider
Gemini string `mapstructure:"gemini"` // creds key for the Google Gemini provider
}

type TLSPolicy int
Expand Down
5 changes: 5 additions & 0 deletions internal/talker.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func NewTalker(tc config.TalkConfig, logger *zap.Logger) (*Talker, error) {
llms = append(llms, llm)
}

if apiKey, ok := tc.Creds[tc.Llm.Gemini]; ok {
llm := providers.NewGemini(apiKey, logger)
llms = append(llms, llm)
}

if apiKey, ok := tc.Creds[tc.TextToSpeech.ElevenLabs]; ok {
tts := providers.NewElevenLabs(apiKey, logger)
ttss = append(ttss, tts)
Expand Down
10 changes: 8 additions & 2 deletions pkg/ability/ability.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,17 @@ type GoogleSTTAb struct {
// LLMAblt describes the large-language-model ability exposed to clients:
// an overall availability flag plus per-provider details.
type LLMAblt struct {
Available bool `json:"available"` // true when at least one LLM provider is usable — assumption; confirm against setter
ChatGPT ChatGPTAblt `json:"chatGPT"` // OpenAI ChatGPT provider status and models
Gemini GeminiAblt `json:"gemini"` // Google Gemini provider status and models
}

type ChatGPTAblt struct {
Available bool `json:"available"`
Models []string `json:"models"`
Available bool `json:"available"`
Models []Model `json:"models"`
}

// GeminiAblt reports whether the Google Gemini provider is configured
// and which models it offers.
type GeminiAblt struct {
Available bool `json:"available"` // true when a Gemini API key is configured
Models []Model `json:"models"` // models the provider exposes
}

// other
Expand Down
15 changes: 13 additions & 2 deletions pkg/ability/defaults.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
package ability

func DefaultChatGPTOption() *ChatGPTOption {
model := "gpt-3.5-turbo"
return &ChatGPTOption{
Model: model,
Model: "gpt-3.5-turbo",
MaxTokens: 2000,
Temperature: 1,
TopP: 1,
PresencePenalty: 0,
FrequencyPenalty: 0,
}
}

// DefaultGeminiOption returns the default generation settings used for the
// Gemini provider. Parameter meanings follow the Gemini model reference:
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
func DefaultGeminiOption() *GeminiOption {
	// StopSequences is intentionally left at its zero value (nil):
	// no client-side stop sequences by default.
	opt := GeminiOption{
		Model:           "gemini-pro",
		MaxOutputTokens: 8192,
		Temperature:     0.9,
		TopP:            1,
		TopK:            32,
	}
	return &opt
}
10 changes: 10 additions & 0 deletions pkg/ability/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "cloud.google.com/go/texttospeech/apiv1/texttospeechpb"
// LLMOption clients use TalkOption to guide LLMAblt in generating text
type LLMOption struct {
ChatGPT *ChatGPTOption `json:"chatGPT"`
Gemini *GeminiOption `json:"gemini"`
}

type ChatGPTOption struct {
Expand All @@ -16,6 +17,15 @@ type ChatGPTOption struct {
FrequencyPenalty float32 `json:"frequencyPenalty"`
}

// GeminiOption carries the generation parameters forwarded to the
// Google Gemini API (see DefaultGeminiOption for the defaults).
type GeminiOption struct {
Model string `json:"model"` // model identifier, e.g. "gemini-pro"
StopSequences []string `json:"stopSequences"` // sequences that end generation early; nil means none
MaxOutputTokens int32 `json:"maxOutputTokens"` // upper bound on tokens generated per response
Temperature float32 `json:"temperature"` // sampling randomness; higher is more random
TopP float32 `json:"topP"` // nucleus-sampling probability cutoff
TopK int32 `json:"topK"` // top-k sampling cutoff
}

type STTOption struct {
Whisper *WhisperOption `json:"whisper"`
Google *GoogleSTTOption `json:"google"`
Expand Down
6 changes: 6 additions & 0 deletions pkg/ability/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package ability

// Model identifies a single LLM model as reported to clients.
// Name is the provider-side identifier; DisplayName is the label meant
// for display (current providers set both to the same string).
type Model struct {
Name string `json:"name" validate:"required"`
DisplayName string `json:"displayName" validate:"required"`
}
24 changes: 21 additions & 3 deletions pkg/client/types.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package client

import (
"fmt"
)

type Message struct {
Role Role `json:"role" validate:"required"` // options: system, user, assistant and function
Content string `json:"content" validate:"required"`
Expand All @@ -8,7 +12,21 @@ type Message struct {
type Role string

const (
RoleUser Role = "user"
RoleAssistant Role = "assistant"
RoleSystem Role = "system"
RoleUser Role = "user"

RoleAssistant Role = "assistant" // for ChatGPT
RoleSystem Role = "system" // for ChatGPT

RoleModel Role = "model" // for Gemini
)

// ToGeminiRole maps a generic chat role onto the vocabulary Gemini expects:
// "user" stays "user", "assistant" becomes "model". Any other role —
// including "system", which Gemini does not accept here — yields an error.
func (r Role) ToGeminiRole() (Role, error) {
	if r == RoleUser {
		return RoleUser, nil
	}
	if r == RoleAssistant {
		return RoleModel, nil
	}
	return "", fmt.Errorf("role %s is invalid", r)
}
13 changes: 8 additions & 5 deletions pkg/providers/chatgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ func (c *chatGPT) Completion(ctx context.Context, ms []client.Message, t ability
// CompletionStream
//
// Return only one chunk that contains the whole content if stream is not supported.
// To make sure the chan closes eventually, caller should either read the last chunk from chan
// or got a chunk whose Err != nil
func (c *chatGPT) CompletionStream(ctx context.Context, ms []client.Message, t ability.LLMOption) *util.SmoothStream {
c.logger.Sugar().Debugw("completion stream...", "message list length", len(ms))
stream := util.NewSmoothStream()
Expand All @@ -104,7 +102,7 @@ func (c *chatGPT) CompletionStream(ctx context.Context, ms []client.Message, t a
}
reqLog := req
reqLog.Messages = nil
c.logger.Sugar().Debugw("completion stream req without messages:", reqLog)
c.logger.Sugar().Debug("completion stream req without messages:", reqLog)

go func() {
s, err := c.client.CreateChatCompletionStream(ctx, req)
Expand Down Expand Up @@ -148,7 +146,7 @@ func (c *chatGPT) Support(o ability.LLMOption) bool {
return o.ChatGPT != nil
}

func (c *chatGPT) getModels(ctx context.Context) ([]string, error) {
func (c *chatGPT) getModels(ctx context.Context) ([]ability.Model, error) {
c.logger.Info("get models...")
ml, err := c.client.ListModels(ctx)
if err != nil {
Expand All @@ -162,7 +160,12 @@ func (c *chatGPT) getModels(ctx context.Context) ([]string, error) {
}
sort.Strings(models)
c.logger.Sugar().Debug("models count:", len(models))
return models, err
ms := make([]ability.Model, len(models))
for i, model := range models {
ms[i].Name = model
ms[i].DisplayName = model
}
return ms, err
}

func messageOfComplete(ms []client.Message) []openai.ChatCompletionMessage {
Expand Down
39 changes: 19 additions & 20 deletions pkg/providers/chatgpt_demo.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package providers
import (
"context"
"errors"
"io"
"math/rand"
"strings"
"time"
Expand All @@ -14,22 +15,22 @@ import (
"go.uber.org/zap"
)

var chatGPTDemoModels = []string{
"gpt-4-32k-0613[demo]",
"gpt-4-32k-0314[demo]",
"gpt-4-32k[demo]",
"gpt-4-0613[demo]",
"gpt-4-0314[demo]",
"gpt-4-1106-preview[demo]",
"gpt-4-vision-preview[demo]",
"gpt-4[demo]",
"gpt-3.5-turbo-1106[demo]",
"gpt-3.5-turbo-0613[demo]",
"gpt-3.5-turbo-0301[demo]",
"gpt-3.5-turbo-16k[demo]",
"gpt-3.5-turbo-16k-0613[demo]",
"gpt-3.5-turbo[demo]",
"gpt-3.5-turbo-instruct[demo]",
// chatGPTDemoModels is the fixed catalogue of model entries the demo
// ChatGPT provider reports. Every entry carries the "[demo]" suffix and
// uses the same string for Name and DisplayName.
var chatGPTDemoModels = func() []ability.Model {
	names := []string{
		"gpt-4-32k-0613[demo]",
		"gpt-4-32k-0314[demo]",
		"gpt-4-32k[demo]",
		"gpt-4-0613[demo]",
		"gpt-4-0314[demo]",
		"gpt-4-1106-preview[demo]",
		"gpt-4-vision-preview[demo]",
		"gpt-4[demo]",
		"gpt-3.5-turbo-1106[demo]",
		"gpt-3.5-turbo-0613[demo]",
		"gpt-3.5-turbo-0301[demo]",
		"gpt-3.5-turbo-16k[demo]",
		"gpt-3.5-turbo-16k-0613[demo]",
		"gpt-3.5-turbo[demo]",
		"gpt-3.5-turbo-instruct[demo]",
	}
	ms := make([]ability.Model, 0, len(names))
	for _, n := range names {
		ms = append(ms, ability.Model{Name: n, DisplayName: n})
	}
	return ms
}()

type chatGPTDemo struct {
Expand Down Expand Up @@ -61,8 +62,6 @@ func (c *chatGPTDemo) Completion(_ context.Context, _ []client.Message, t abilit
// CompletionStream
//
// Return only one chunk that contains the whole content if stream is not supported.
// To make sure the chan closes eventually, caller should either read the last chunk from chan
// or got a chunk whose Err != nil
func (c *chatGPTDemo) CompletionStream(_ context.Context, ms []client.Message, t ability.LLMOption) *util.SmoothStream {
c.logger.Sugar().Debugw("completion stream...", "message list length", len(ms))
stream := util.NewSmoothStream()
Expand All @@ -81,7 +80,7 @@ func (c *chatGPTDemo) CompletionStream(_ context.Context, ms []client.Message, t
time.Sleep(time.Duration(rand.Intn(150)) * time.Millisecond)
}
}
stream.DoneWrite()
stream.WriteError(io.EOF)
}()
return stream
}
Expand Down Expand Up @@ -115,7 +114,7 @@ func (c *chatGPTDemo) Support(o ability.LLMOption) bool {
return o.ChatGPT != nil
}

func (c *chatGPTDemo) getModels(_ context.Context) ([]string, error) {
func (c *chatGPTDemo) getModels(_ context.Context) ([]ability.Model, error) {
c.logger.Info("get models...")

c.logger.Sugar().Debug("models count:", len(chatGPTDemoModels))
Expand Down
Loading

0 comments on commit 7851934

Please sign in to comment.