From cdb32b85d917bc3aa5871bbd9e4dc8db3a4cb47f Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Wed, 2 Oct 2024 13:13:35 +0100 Subject: [PATCH 1/2] Remove config and logging This makes it more library like. It's so long since I wrote a lib that I added a configParser and logging framework, these should instead be managed by whomever is using the library in their own application. I also added a bit more fluff such as CI and a good lint clean up --- .github/workflows/test.yml | 22 ++++++ .golangci.yml | 103 ++++++++++++++++++++++++++ README.md | 44 ++---------- examples/config-example.yaml | 14 ---- examples/main.go | 111 ++++++++--------------------- go.mod | 24 ++----- pkg/backend/ollama_backend.go | 33 +++------ pkg/backend/ollama_backend_test.go | 26 ++++--- pkg/backend/openai_backend.go | 27 ++----- pkg/backend/openai_backend_test.go | 13 +++- pkg/config/config.go | 68 ------------------ pkg/config/config_test.go | 96 ------------------------- pkg/db/pgvector.go | 39 ++++++++-- pkg/logger/logger.go | 61 ---------------- 14 files changed, 245 insertions(+), 436 deletions(-) create mode 100644 .github/workflows/test.yml create mode 100644 .golangci.yml delete mode 100644 examples/config-example.yaml delete mode 100644 pkg/config/config.go delete mode 100644 pkg/config/config_test.go delete mode 100644 pkg/logger/logger.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..cae7118 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,22 @@ +name: Main +on: + push: + branches: + - main + paths-ignore: + - 'docs/**' +permissions: + contents: read + packages: write + +jobs: + test: + name: Unit testing + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Run tests + run: go test -cover ./... 
+ - name: Lint code + run: golangci-lint run diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..b4018db --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,103 @@ +run: + issues-exit-code: 1 + timeout: 5m + +linters-settings: + lll: + line-length: 130 + gocyclo: + min-complexity: 15 + gci: + sections: + - standard + - default + - prefix(github.com/stackloklabs/gollm) + revive: + # see https://github.com/mgechev/revive#available-rules for details. + ignore-generated-header: true + severity: warning + errorCode: 0 + warningCode: 0 + rules: + - name: blank-imports + severity: warning + - name: context-as-argument + - name: context-keys-type + - name: duplicated-imports + - name: error-naming + # - name: error-strings #BDG: This was enabled for months, but it suddenly started working on 3/2/2022.. come to find out we have TONS of error messages starting with capital... disabling for now(ever?) + - name: error-return + - name: exported + severity: error + - name: if-return + # - name: get-return // BDG: We have a lot of API endpoint handlers named like getFoos but write to response vs return... 
maybe later can figure that out + - name: identical-branches + - name: indent-error-flow + - name: import-shadowing + - name: package-comments + # NOTE: range-val-address and range-val-in-closure are irrelevant in Go 1.22 and later + - name: redefines-builtin-id + - name: struct-tag + - name: unconditional-recursion + - name: unnecessary-stmt + - name: unreachable-code + - name: unused-parameter + - name: unused-receiver + - name: unhandled-error + disabled: true + gosec: + excludes: + - G114 # for the moment we need to use listenandserve that has no support for timeouts + - G404 # use unsafe random generator until logic change is discussed + - G307 # Deferring unsafe method "Close" on type "io.ReadCloser" + - G601 # Irrelevant for Go 1.22 and later, see: https://github.com/securego/gosec/issues/1099 + + depguard: + rules: + prevent_unmaintained_packages: + list-mode: lax # allow unless explicitly denied + files: + - $all + - "!$test" + deny: + - pkg: "log" + desc: "We should use zerolog instead" + - pkg: io/ioutil + desc: "this is deprecated" + +linters: + disable-all: true + enable: + - lll + - exhaustive + - depguard + - goconst + - gocyclo + - gofmt + - gosec + - gci + - unparam + - gosimple + - govet + - ineffassign + - paralleltest + - promlinter + - revive + - staticcheck + - unused + - thelper + - tparallel + +issues: + exclude-use-default: false + exclude-rules: + - path: '(.+)_test\.go' + linters: + - lll + +output: + formats: + - format: colored-line-number + print-issued-lines: true + print-linter-name: true + sort-results: true diff --git a/README.md b/README.md index 00e6882..8bd66a9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gollm: Go Interface for LLM development 📜 +# Gollm: Go Interface for LLM development with RAG 📜 [![Go Report Card](https://goreportcard.com/badge/github.com/stackloklabs/gollm)](https://goreportcard.com/report/github.com/stackloklabs/gollm)
[![License](https://img.shields.io/github/license/stackloklabs/gollm)](LICENSE) @@ -19,13 +19,12 @@ Language Model backends including [Ollama](https://ollama.com) and [OpenAI](http ### 1. Installation -First, make sure you have Go installed. Then, add Gollm to your project: +First, make sure you have Go installed. Then, add gollm to your project: ```bash go get github.com/stackloklabs/gollm ``` - ## 2. Setting Up Ollama You'll need to have an Ollama server running and accessible. @@ -35,7 +34,7 @@ Install Ollama Server: Download the server from the [official Ollama website](ht Pull and run a model ```bash -ollama run qwen2.5 +ollama run llama3 ``` Ollama should run on port `11434` and `localhost`, if you change this, don't @@ -43,51 +42,20 @@ forget to update your config. ## 3. OpenAI -You'll need an OpenAI API key to use the OpenAI backend, which can be be -set within the config as below. +You'll need an OpenAI API key to use the OpenAI backend. ## 4. Configuration -Gollm uses Viper to manage configuration settings. - -Backends are configured for either generation or embeddings, and can be set to either Ollama or OpenAI. - -For each backend Models is set. Note that for Ollama you will need to -have this as running model, e.g. `ollama run qwen2.5` or `ollama run mxbai-embed-large`. - -Finally, in the case of RAG embeddings, a database URL is required. - -Currently Postgres is supported, and the database should be created before running the application, with the schena provided in `db/init.sql` +Currently Postgres is supported, and the database should be created before +running the application, with the schema provided in `db/init.sql` Should you wish, the docker-compose will automate the setup of the database.
-```bash -cp examples/config-example.yaml ./config.yaml -``` - -```yaml -backend: - embeddings: "ollama" # or "ollama" - generation: "ollama" # or "openai" -ollama: - host: "http://localhost:11434" - gen_model: "qwen2.5" - emb_model: "mxbai-embed-large" -openai: - api_key: "your-key" - gen_model: "gpt-3.5-turbo" - emb_model: "text-embedding-ada-002" -database: - url: "postgres://user:password@localhost:5432/dbname?sslmode=disable" -log_level: "info" -``` - # 🛠️ Usage Best bet is to see `/examples/main.go` for reference, this explains how to use the library with full examples for generation, embeddings and implementing RAG. - # 📝 Contributing We welcome contributions! Please submit a pull request or raise an issue if diff --git a/examples/config-example.yaml b/examples/config-example.yaml deleted file mode 100644 index 471211b..0000000 --- a/examples/config-example.yaml +++ /dev/null @@ -1,14 +0,0 @@ -backend: - embeddings: "ollama" # or "ollama" - generation: "ollama" # or "openai" -ollama: - host: "http://localhost:11434" - gen_model: "qwen2.5" - emb_model: "mxbai-embed-large" -openai: - api_key: "your-key" - gen_model: "gpt-3.5-turbo" - emb_model: "text-embedding-ada-002" -database: - url: "postgres://user:password@localhost:5432/dbname?sslmode=disable" -log_level: "info" \ No newline at end of file diff --git a/examples/main.go b/examples/main.go index d2d61ae..a40511b 100644 --- a/examples/main.go +++ b/examples/main.go @@ -2,61 +2,45 @@ package main import ( "context" - "fmt" + "log" "time" - "github.com/google/uuid" - "github.com/stackloklabs/gollm/pkg/backend" - "github.com/stackloklabs/gollm/pkg/config" "github.com/stackloklabs/gollm/pkg/db" - "github.com/stackloklabs/gollm/pkg/logger" +) + +var ( + ollamaHost = "http://localhost:11434" + ollamaEmbModel = "mxbai-embed-large" + ollamaGenModel = "llama3" + databaseURL = "postgres://user:password@localhost:5432/dbname?sslmode=disable" ) func main() { // Initialize Config - cfg := 
config.InitializeViperConfig("config", "yaml", ".") - - logger.InitLogger() // Select backends based on config var embeddingBackend backend.Backend var generationBackend backend.Backend // Choose the backend for embeddings based on the config - switch cfg.Get("backend.embeddings") { - case "ollama": - embeddingBackend = backend.NewOllamaBackend(cfg.Get("ollama.host"), cfg.Get("ollama.emb_model")) - case "openai": - embeddingBackend = backend.NewOpenAIBackend(cfg.Get("openai.api_key"), cfg.Get("openai.emb_model")) - default: - logger.Fatal("Invalid embeddings backend specified") - } - logger.Info(fmt.Sprintf("Embeddings backend: %s", cfg.Get("backend.embeddings"))) + embeddingBackend = backend.NewOllamaBackend(ollamaHost, ollamaEmbModel) - // Choose the backend for generation based on the config - switch cfg.Get("backend.generation") { - case "ollama": - generationBackend = backend.NewOllamaBackend(cfg.Get("ollama.host"), cfg.Get("ollama.gen_model")) - case "openai": - generationBackend = backend.NewOpenAIBackend(cfg.Get("openai.api_key"), cfg.Get("openai.gen_model")) - default: - logger.Fatal("Invalid generation backend specified") - } + log.Printf("Embedding backend LLM: %s", ollamaEmbModel) - logger.Info(fmt.Sprintf("Generation backend: %s", cfg.Get("backend.generation"))) + // Choose the backend for generation based on the config + generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel) - // Initialize database connection for pgvector - dbConnString := cfg.Get("database.url") + log.Printf("Generation backend: %s", ollamaGenModel) // Initialize the vector database - vectorDB, err := db.NewPGVector(dbConnString) + vectorDB, err := db.NewPGVector(databaseURL) if err != nil { - logger.Fatalf("Failed to initialize vector database: %v", err) + log.Fatalf("Error initializing vector database: %v", err) } - logger.Info("Vector database initialized") + log.Println("Vector database initialized") ctx, cancel := context.WithTimeout(context.Background(), 
30*time.Second) defer cancel() @@ -69,79 +53,44 @@ func main() { // Embed the query using OpenAI embedding, err := embeddingBackend.Embed(ctx, ragContent) if err != nil { - logger.Fatalf("Error generating embedding: %v", err) + log.Fatalf("Error generating embedding: %v", err) } - - // Check 1536 is the expected Dimensions value (1536 is the OpenAI default) - // expectedDimensions := 1536 - // if len(embedding) != expectedDimensions { - // logger.Fatalf("Error: embedding dimensions mismatch. Expected %d, got %d", expectedDimensions, len(embedding)) - // } + log.Println("Embedding generated") // Insert the document into the vector store - err = insertDocument(vectorDB, ctx, ragContent, embedding) + err = db.InsertDocument(ctx, vectorDB, ragContent, embedding) if err != nil { - logger.Fatalf("Failed to insert document into vectorDB: %v", err) + log.Fatalf("Error inserting document: %v", err) } + log.Println("Vector Document generated") // Embed the query using the specified embedding backend queryEmbedding, err := embeddingBackend.Embed(ctx, query) if err != nil { - logger.Fatalf("Error generating query embedding: %v", err) + log.Fatalf("Error generating query embedding: %v", err) } + log.Println("Vector embeddings generated") // Retrieve relevant documents for the query embedding - retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding, cfg.Get("backend.embeddings")) + retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding, "ollama") if err != nil { - logger.Fatalf("Error retrieving documents: %v", err) + log.Fatalf("Error retrieving relevant documents: %v", err) } // Log the retrieved documents to see if they include the inserted content for _, doc := range retrievedDocs { - logger.Infof("RAG Retrieved Document ID: %s, Content: %v", doc.ID, doc.Metadata["content"]) + log.Printf("Retrieved Document: %v", doc) } // Augment the query with retrieved context - augmentedQuery := combineQueryWithContext(query, retrievedDocs) - 
logger.Infof("Augmented query Constructed using Prompt: %s", query) - - // logger.Infof("Augmented Query: %s", augmentedQuery) + augmentedQuery := db.CombineQueryWithContext(query, retrievedDocs) + log.Printf("LLM Prompt: %s", query) // Generate response with the specified generation backend response, err := generationBackend.Generate(ctx, augmentedQuery) if err != nil { - logger.Fatalf("Failed to generate response: %v", err) + log.Fatalf("Failed to generate response: %v", err) } - logger.Infof("Output from LLM model %s:", response) -} - -// combineQueryWithContext combines the query and retrieved documents' content to provide context for generation. -func combineQueryWithContext(query string, docs []db.Document) string { - var context string - for _, doc := range docs { - // Cast doc.Metadata["content"] to a string - if content, ok := doc.Metadata["content"].(string); ok { - context += content + "\n" - } - } - return fmt.Sprintf("Context: %s\nQuery: %s", context, query) -} - -// Example code to insert a document into the vector store -func insertDocument(vectorDB *db.PGVector, ctx context.Context, content string, embedding []float32) error { - // Generate a unique document ID (for simplicity, using a static value for testing) - docID := fmt.Sprintf("doc-%s", uuid.New().String()) - - // Create metadata - metadata := map[string]interface{}{ - "content": content, - } - - // Save the document and its embedding into the vector store - err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) - if err != nil { - return fmt.Errorf("error saving embedding: %v", err) - } - return nil + log.Printf("Retrieval-Augmented Generation influenced output from LLM model: %s", response) } diff --git a/go.mod b/go.mod index e762c97..15948f8 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,10 @@ require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v4 v4.18.3 github.com/pgvector/pgvector-go v0.2.2 - github.com/rs/zerolog v1.15.0 - github.com/spf13/viper v1.19.0 ) require ( - 
github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/hashicorp/hcl v1.0.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgio v1.0.0 // indirect @@ -21,22 +18,9 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle v1.3.0 // indirect - github.com/magiconair/properties v1.8.7 // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/sagikazarmark/locafero v0.4.0 // indirect - github.com/sagikazarmark/slog-shim v0.1.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect - github.com/subosito/gotenv v1.6.0 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.9.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/stretchr/testify v1.9.0 // indirect golang.org/x/crypto v0.25.0 // indirect - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect - gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/pkg/backend/ollama_backend.go b/pkg/backend/ollama_backend.go index 9b26bc5..36d8b4b 100644 --- a/pkg/backend/ollama_backend.go +++ b/pkg/backend/ollama_backend.go @@ -22,8 +22,6 @@ import ( "io" "net/http" "time" - - "github.com/stackloklabs/gollm/pkg/logger" ) const ( @@ -82,47 +80,41 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal 
request body: %v", err) return "", fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") - logger.Infof("Sending augmented prompt to Ollama LLM model: %s", o.Model) - resp, err := o.Client.Do(req) if err != nil { - logger.Errorf("HTTP request failed: %v", err) return "", fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate response from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return "", fmt.Errorf("failed to generate response from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return "", fmt.Errorf( + "failed to generate response from Ollama: "+ + "status code %d, response: %s", + resp.StatusCode, string(bodyBytes), + ) } var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return "", fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received response from Ollama for LLM model %s", result.Model) - return result.Response, nil } // Embed generates embeddings for the given input text using the Ollama API. 
func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, error) { - logger.Infof("Ollama Embedding model %s prompt: %s", o.Model, input) - logger.Infof("Sending request to Ollama API at %s with Embedding model: %s", o.BaseURL, o.Model) + url := o.BaseURL + embedEndpoint reqBody := map[string]interface{}{ "model": o.Model, @@ -131,37 +123,34 @@ func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, err reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal request body: %v", err) return nil, fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := o.Client.Do(req) if err != nil { - logger.Errorf("HTTP request failed: %v", err) return nil, fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate embeddings from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return nil, fmt.Errorf("failed to generate embeddings from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf( + "failed to generate embeddings from Ollama: "+ + "status code %d, response: %s", + resp.StatusCode, string(bodyBytes), + ) } var result OllamaEmbeddingResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return nil, fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received vector embeddings from Ollama model %s", o.Model) - return result.Embedding, nil } diff --git a/pkg/backend/ollama_backend_test.go b/pkg/backend/ollama_backend_test.go index 
f9eea60..e24cbe6 100644 --- a/pkg/backend/ollama_backend_test.go +++ b/pkg/backend/ollama_backend_test.go @@ -22,7 +22,11 @@ import ( "time" ) +const contentTypeJSON = "application/json" +const testEmbeddingText = "Test embedding text." + func TestOllamaGenerate(t *testing.T) { + t.Parallel() // Mock response from Ollama API mockResponse := Response{ Model: "test-model", @@ -39,7 +43,8 @@ func TestOllamaGenerate(t *testing.T) { } // Check Content-Type header - if r.Header.Get("Content-Type") != "application/json" { + + if r.Header.Get("Content-Type") != contentTypeJSON { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } @@ -61,8 +66,10 @@ func TestOllamaGenerate(t *testing.T) { } // Write the mock response - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + w.Header().Set("Content-Type", contentTypeJSON) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() @@ -88,6 +95,7 @@ func TestOllamaGenerate(t *testing.T) { } func TestOllamaEmbed(t *testing.T) { + t.Parallel() // Mock response from Ollama API mockResponse := OllamaEmbeddingResponse{ Embedding: []float32{0.1, 0.2, 0.3}, @@ -101,7 +109,7 @@ func TestOllamaEmbed(t *testing.T) { } // Check Content-Type header - if r.Header.Get("Content-Type") != "application/json" { + if r.Header.Get("Content-Type") != contentTypeJSON { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } @@ -115,13 +123,15 @@ func TestOllamaEmbed(t *testing.T) { if reqBody["model"] != "test-model" { t.Errorf("Expected model 'test-model', got '%v'", reqBody["model"]) } - if reqBody["prompt"] != "Test embedding text." 
{ + if reqBody["prompt"] != testEmbeddingText { t.Errorf("Expected prompt 'Test embedding text.', got '%v'", reqBody["prompt"]) } // Write the mock response - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + w.Header().Set("Content-Type", contentTypeJSON) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() @@ -133,7 +143,7 @@ func TestOllamaEmbed(t *testing.T) { } ctx := context.Background() - input := "Test embedding text." + input := testEmbeddingText embedding, err := backend.Embed(ctx, input) if err != nil { diff --git a/pkg/backend/openai_backend.go b/pkg/backend/openai_backend.go index 34bd0ae..f726a58 100644 --- a/pkg/backend/openai_backend.go +++ b/pkg/backend/openai_backend.go @@ -21,8 +21,6 @@ import ( "fmt" "io" "net/http" - - "github.com/stackloklabs/gollm/pkg/logger" ) // OpenAIBackend represents a backend for interacting with the OpenAI API. 
@@ -111,42 +109,35 @@ func (o *OpenAIBackend) Generate(ctx context.Context, prompt string) (string, er reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal request body: %v", err) + return "", fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+o.APIKey) - logger.Infof("Sending augmented prompt to OpenAI API at %s with model %s", url, o.Model) - resp, err := o.HTTPClient.Do(req) if err != nil { - logger.Errorf("HTTP request failed: %v", err) return "", fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate response from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return "", fmt.Errorf("failed to generate response from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return "", fmt.Errorf("failed to generate response from OpenAI: "+ + "status code %d, response: %s", resp.StatusCode, string(bodyBytes)) } var result OpenAIResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return "", fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received response from OpenAI for model %s", result.Model) - return result.Choices[0].Message.Content, nil } @@ -160,8 +151,6 @@ func (o *OpenAIBackend) Generate(ctx context.Context, prompt string) (string, er // - A slice of float32 values representing the embedding vector. // - An error if the API request fails or if there's an issue processing the response. 
func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, error) { - logger.Infof("OpenAI Embedding model %s prompt: %s", o.Model, text) - logger.Infof("Sending request to OpenAI API at %s with Embedding model: %s", o.BaseURL, o.Model) url := o.BaseURL + "/v1/embeddings" reqBody := map[string]interface{}{ "model": o.Model, @@ -170,13 +159,11 @@ func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, erro reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal request body: %v", err) return nil, fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return nil, fmt.Errorf("failed to create request: %w", err) } @@ -185,24 +172,20 @@ func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, erro resp, err := o.HTTPClient.Do(req) if err != nil { - logger.Errorf("HTTP request failed: %v", err) return nil, fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate embedding from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return nil, fmt.Errorf("failed to generate embedding from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf("failed to generate embedding from OpenAI: "+ + "status code %d, response: %s", resp.StatusCode, string(bodyBytes)) } var result OpenAIEmbeddingResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return nil, fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received vector embeddings from OpenAI model %s", o.Model) - return result.Data[0].Embedding, nil } diff --git a/pkg/backend/openai_backend_test.go 
b/pkg/backend/openai_backend_test.go index 5593923..93a5965 100644 --- a/pkg/backend/openai_backend_test.go +++ b/pkg/backend/openai_backend_test.go @@ -23,6 +23,7 @@ import ( ) func TestGenerate(t *testing.T) { + t.Parallel() // Mock response from OpenAI API mockResponse := OpenAIResponse{ ID: "test-id", @@ -76,7 +77,9 @@ func TestGenerate(t *testing.T) { // Write the mock response w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() @@ -104,6 +107,7 @@ func TestGenerate(t *testing.T) { } func TestGenerateEmbedding(t *testing.T) { + t.Parallel() // Mock response from OpenAI API mockResponse := OpenAIEmbeddingResponse{ Object: "list", @@ -144,7 +148,9 @@ func TestGenerateEmbedding(t *testing.T) { // Write the mock response w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() diff --git a/pkg/config/config.go b/pkg/config/config.go deleted file mode 100644 index 19ba8ac..0000000 --- a/pkg/config/config.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Stacklok, Inc -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "github.com/spf13/viper" - "log" -) - -type Config interface { - Get(key string) string - GetInt(key string) int - GetBool(key string) bool -} - -// ViperConfig implements the Config interface using Viper. -type ViperConfig struct { - viper *viper.Viper -} - -// NewViperConfig initializes a ViperConfig with a given Viper instance. -func NewViperConfig(v *viper.Viper) *ViperConfig { - return &ViperConfig{viper: v} -} - -// Get returns a string value for the given key. -func (vc *ViperConfig) Get(key string) string { - return vc.viper.GetString(key) -} - -// GetInt returns an integer value for the given key. -func (vc *ViperConfig) GetInt(key string) int { - return vc.viper.GetInt(key) -} - -// GetBool returns a boolean value for the given key. -func (vc *ViperConfig) GetBool(key string) bool { - return vc.viper.GetBool(key) -} - -// InitializeViperConfig initializes and returns a Config implementation using Viper. -// It reads the configuration from the specified config file and paths. 
-func InitializeViperConfig(configName, configType, configPath string) Config { - v := viper.New() - v.SetConfigName(configName) - v.SetConfigType(configType) - v.AddConfigPath(configPath) - - // Read in the config file - if err := v.ReadInConfig(); err != nil { - log.Fatalf("Error reading config file: %v", err) - } - - // Wrap Viper with ViperConfig and return as Config - return NewViperConfig(v) -} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go deleted file mode 100644 index 98eb547..0000000 --- a/pkg/config/config_test.go +++ /dev/null @@ -1,96 +0,0 @@ -// config_test.go -package config - -import ( - "github.com/spf13/viper" - "os" - "testing" -) - -func TestViperConfig_Get(t *testing.T) { - // Create a new Viper instance and set a string value - v := viper.New() - v.Set("stringKey", "stringValue") - - // Initialize ViperConfig with the Viper instance - vc := NewViperConfig(v) - - // Test the Get method - value := vc.Get("stringKey") - if value != "stringValue" { - t.Errorf("Expected 'stringValue', got '%s'", value) - } -} - -func TestViperConfig_GetInt(t *testing.T) { - // Create a new Viper instance and set an integer value - v := viper.New() - v.Set("intKey", 42) - - // Initialize ViperConfig with the Viper instance - vc := NewViperConfig(v) - - // Test the GetInt method - value := vc.GetInt("intKey") - if value != 42 { - t.Errorf("Expected 42, got %d", value) - } -} - -func TestViperConfig_GetBool(t *testing.T) { - // Create a new Viper instance and set a boolean value - v := viper.New() - v.Set("boolKey", true) - - // Initialize ViperConfig with the Viper instance - vc := NewViperConfig(v) - - // Test the GetBool method - value := vc.GetBool("boolKey") - if value != true { - t.Errorf("Expected true, got %v", value) - } -} - -func TestInitializeViperConfig(t *testing.T) { - // Since InitializeViperConfig reads from a file, we'll create a temporary config file for testing - configName := "testconfig" - configType := "yaml" - configPath := 
"." - - // Create a temporary config file with some test data - testConfigContent := ` -stringKey: stringValue -intKey: 42 -boolKey: true -` - // Write the test config content to a temporary file - configFileName := configName + "." + configType - err := writeTempConfigFile(configFileName, testConfigContent) - if err != nil { - t.Fatalf("Failed to write temp config file: %v", err) - } - defer removeTempConfigFile(configFileName) - - // Initialize the config - cfg := InitializeViperConfig(configName, configType, configPath) - - // Test the values - if cfg.Get("stringKey") != "stringValue" { - t.Errorf("Expected 'stringValue', got '%s'", cfg.Get("stringKey")) - } - if cfg.GetInt("intKey") != 42 { - t.Errorf("Expected 42, got %d", cfg.GetInt("intKey")) - } - if cfg.GetBool("boolKey") != true { - t.Errorf("Expected true, got %v", cfg.GetBool("boolKey")) - } -} - -func writeTempConfigFile(filename, content string) error { - return os.WriteFile(filename, []byte(content), 0644) -} - -func removeTempConfigFile(filename string) { - os.Remove(filename) -} diff --git a/pkg/db/pgvector.go b/pkg/db/pgvector.go index beb4135..445da21 100644 --- a/pkg/db/pgvector.go +++ b/pkg/db/pgvector.go @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. package db import ( @@ -19,10 +23,9 @@ import ( "fmt" "strings" + "github.com/google/uuid" "github.com/jackc/pgx/v4/pgxpool" "github.com/pgvector/pgvector-go" - - "github.com/stackloklabs/gollm/pkg/logger" ) // PGVector represents a connection to a PostgreSQL database with pgvector extension. 
@@ -86,8 +89,6 @@ func (pg *PGVector) SaveEmbedding(ctx context.Context, docID string, embedding [ // Log the error for debugging purposes return fmt.Errorf("failed to insert document: %w", err) } - - logger.Infof("Document inserted successfully: %s", docID) return nil } @@ -165,3 +166,33 @@ func ConvertEmbeddingToPGVector(embedding []float32) string { } return fmt.Sprintf("{%s}", strings.Join(strValues, ",")) } + +// CombineQueryWithContext combines the query and retrieved documents' content to provide context for generation. +func CombineQueryWithContext(query string, docs []Document) string { + var contextStr string + for _, doc := range docs { + // Cast doc.Metadata["content"] to a string + if content, ok := doc.Metadata["content"].(string); ok { + contextStr += content + "\n" + } + } + return fmt.Sprintf("Context: %s\nQuery: %s", contextStr, query) +} + +// InsertDocument inserts a document into the vector store. +func InsertDocument(ctx context.Context, vectorDB *PGVector, content string, embedding []float32) error { + // Generate a unique document ID using a random UUID + docID := fmt.Sprintf("doc-%s", uuid.New().String()) + + // Create metadata + metadata := map[string]interface{}{ + "content": content, + } + + // Save the document and its embedding into the vector store + err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + if err != nil { + return fmt.Errorf("error saving embedding: %w", err) + } + return nil +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go deleted file mode 100644 index 6736fad..0000000 --- a/pkg/logger/logger.go +++ /dev/null @@ -1,61 +0,0 @@ -package logger - -import ( - "os" - - "github.com/rs/zerolog" - "github.com/stackloklabs/gollm/pkg/config" -) - -var logger zerolog.Logger - -func InitLogger() { - cfg := config.InitializeViperConfig("config", "yaml", ".") - // Set the time format for logs - zerolog.TimeFieldFormat = zerolog.TimeFormatUnix - - // Retrieve log level from 
configuration (using viper as an example) - log_level := cfg.Get("log_level") - - // Parse log level - level, err := zerolog.ParseLevel(log_level) - if err != nil { - level = zerolog.InfoLevel // Default to InfoLevel - } - - logger = zerolog.New(os.Stdout).With().Timestamp().Logger() - zerolog.SetGlobalLevel(level) -} - -// Info logs an info level message -func Info(msg string) { - logger.Info().Msg(msg) -} - -// Infof logs an info level message with formatting -func Infof(format string, v ...interface{}) { - logger.Info().Msgf(format, v...) -} - -// Debug logs a debug level message -func Debug(msg string) { - logger.Debug().Msg(msg) -} - -// Fatal logs a fatal level message and then exits the program -func Fatal(msg string) { - logger.Fatal().Msg(msg) -} - -// Fatalf logs a fatal level message with formatting and then exits the program -func Fatalf(format string, v ...interface{}) { - logger.Fatal().Msgf(format, v...) -} - -func Errorf(format string, v ...interface{}) { - logger.Error().Msgf(format, v...) -} - -func Error(msg string) { - logger.Error().Msg(msg) -} From 686bfd47bb996ded7913d2144e6ed36aee3fe220 Mon Sep 17 00:00:00 2001 From: Luke Hinds Date: Wed, 2 Oct 2024 13:13:35 +0100 Subject: [PATCH 2/2] Remove config and logging This makes it more library like. It's so long since I wrote a lib that I added a configParser and logging framework, these should instead be managed by whomever is using the library in their own application. 
I also added a bit more fluff such as CI and a good lint clean up --- .github/workflows/test.yml | 22 ++++++ .golangci.yml | 103 ++++++++++++++++++++++++++ README.md | 44 ++---------- examples/config-example.yaml | 14 ---- examples/main.go | 111 ++++++++--------------------- go.mod | 24 ++----- pkg/backend/ollama_backend.go | 33 +++------ pkg/backend/ollama_backend_test.go | 26 ++++--- pkg/backend/openai_backend.go | 27 ++----- pkg/backend/openai_backend_test.go | 13 +++- pkg/config/config.go | 68 ------------------ pkg/config/config_test.go | 96 ------------------------- pkg/db/pgvector.go | 39 ++++++++-- pkg/logger/logger.go | 61 ---------------- 14 files changed, 245 insertions(+), 436 deletions(-) create mode 100644 .github/workflows/test.yml create mode 100644 .golangci.yml delete mode 100644 examples/config-example.yaml delete mode 100644 pkg/config/config.go delete mode 100644 pkg/config/config_test.go delete mode 100644 pkg/logger/logger.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..cae7118 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,22 @@ +name: Main +on: + push: + branches: + - main + paths-ignore: + - 'docs/**' +permissions: + contents: read + packages: write + +jobs: + test: + name: Unit testing + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Run tests + run: go test -cover ./... + - name: Lint code + run: golangci-lint run diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..b4018db --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,103 @@ +run: + issues-exit-code: 1 + timeout: 5m + +linters-settings: + lll: + line-length: 130 + gocyclo: + min-complexity: 15 + gci: + sections: + - standard + - default + - prefix(github.com/stackloklabs/gollm) + revive: + # see https://github.com/mgechev/revive#available-rules for details. 
+ ignore-generated-header: true + severity: warning + errorCode: 0 + warningCode: 0 + rules: + - name: blank-imports + severity: warning + - name: context-as-argument + - name: context-keys-type + - name: duplicated-imports + - name: error-naming + # - name: error-strings #BDG: This was enabled for months, but it suddenly started working on 3/2/2022.. come to find out we have TONS of error messages starting with capital... disabling for now(ever?) + - name: error-return + - name: exported + severity: error + - name: if-return + # - name: get-return // BDG: We have a lot of API endpoint handlers named like getFoos but write to response vs return... maybe later can figure that out + - name: identical-branches + - name: indent-error-flow + - name: import-shadowing + - name: package-comments + # NOTE: range-val-address and range-val-in-closure are irrelevant in Go 1.22 and later + - name: redefines-builtin-id + - name: struct-tag + - name: unconditional-recursion + - name: unnecessary-stmt + - name: unreachable-code + - name: unused-parameter + - name: unused-receiver + - name: unhandled-error + disabled: true + gosec: + excludes: + - G114 # for the moment we need to use listenandserve that has no support for timeouts + - G404 # use unsafe random generator until logic change is discussed + - G307 # Deferring unsafe method "Close" on type "io.ReadCloser" + - G601 # Irrelevant for Go 1.22 and later, see: https://github.com/securego/gosec/issues/1099 + + depguard: + rules: + prevent_unmaintained_packages: + list-mode: lax # allow unless explicitely denied + files: + - $all + - "!$test" + deny: + - pkg: "log" + desc: "We should use zerolog instead" + - pkg: io/ioutil + desc: "this is deprecated" + +linters: + disable-all: true + enable: + - lll + - exhaustive + - depguard + - goconst + - gocyclo + - gofmt + - gosec + - gci + - unparam + - gosimple + - govet + - ineffassign + - paralleltest + - promlinter + - revive + - staticcheck + - unused + - thelper + - tparallel + 
+issues: + exclude-use-default: false + exclude-rules: + - path: '(.+)_test\.go' + linters: + - lll + +output: + formats: + - format: colored-line-number + print-issued-lines: true + print-linter-name: true + sort-results: true diff --git a/README.md b/README.md index 00e6882..8bd66a9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gollm: Go Interface for LLM development 📜 +# Gollm: Go Interface for LLM development with RAG 📜 [![Go Report Card](https://goreportcard.com/badge/github.com/stackloklabs/gollm)](https://goreportcard.com/report/github.com/stackloklabs/gollm) [![License](https://img.shields.io/github/license/stackloklabs/gollm)](LICENSE) @@ -19,13 +19,12 @@ Language Model backends including [Ollama](https://ollama.com) and [OpenAI](http ### 1. Installation -First, make sure you have Go installed. Then, add Gollm to your project: +First, make sure you have Go installed. Then, add gollm to your project: ```bash go get github.com/stackloklabs/gollm ``` - ## 2. Setting Up Ollama You'll need to have an Ollama server running and accessible. @@ -35,7 +34,7 @@ Install Ollama Server: Download the server from the [official Ollama website](ht Pull and run a model ```bash -ollama run qwen2.5 +ollama run llama3 ``` Ollama should run on port `11434` and `localhost`, if you change this, don't @@ -43,51 +42,20 @@ forget to update your config. ## 3. OpenAI -You'll need an OpenAI API key to use the OpenAI backend, which can be be -set within the config as below. +You'll need an OpenAI API key to use the OpenAI backend. ## 4. Configuration -Gollm uses Viper to manage configuration settings. - -Backends are configured for either generation or embeddings, and can be set to either Ollama or OpenAI. - -For each backend Models is set. Note that for Ollama you will need to -have this as running model, e.g. `ollama run qwen2.5` or `ollama run mxbai-embed-large`. - -Finally, in the case of RAG embeddings, a database URL is required. 
- -Currently Postgres is supported, and the database should be created before running the application, with the schena provided in `db/init.sql` +Currently Postgres is supported, and the database should be created before +running the application, with the schema provided in `db/init.sql` Should you wish, the docker-compose will automate the setup of the database. -```bash -cp examples/config-example.yaml ./config.yaml -``` - -```yaml -backend: - embeddings: "ollama" # or "ollama" - generation: "ollama" # or "openai" -ollama: - host: "http://localhost:11434" - gen_model: "qwen2.5" - emb_model: "mxbai-embed-large" -openai: - api_key: "your-key" - gen_model: "gpt-3.5-turbo" - emb_model: "text-embedding-ada-002" -database: - url: "postgres://user:password@localhost:5432/dbname?sslmode=disable" -log_level: "info" -``` - # 🛠️ Usage Best bet is to see `/examples/main.go` for reference, this explains how to use the library with full examples for generation, embeddings and implementing RAG. - # 📝 Contributing We welcome contributions!
Please submit a pull request or raise an issue if diff --git a/examples/config-example.yaml b/examples/config-example.yaml deleted file mode 100644 index 471211b..0000000 --- a/examples/config-example.yaml +++ /dev/null @@ -1,14 +0,0 @@ -backend: - embeddings: "ollama" # or "ollama" - generation: "ollama" # or "openai" -ollama: - host: "http://localhost:11434" - gen_model: "qwen2.5" - emb_model: "mxbai-embed-large" -openai: - api_key: "your-key" - gen_model: "gpt-3.5-turbo" - emb_model: "text-embedding-ada-002" -database: - url: "postgres://user:password@localhost:5432/dbname?sslmode=disable" -log_level: "info" \ No newline at end of file diff --git a/examples/main.go b/examples/main.go index d2d61ae..a40511b 100644 --- a/examples/main.go +++ b/examples/main.go @@ -2,61 +2,45 @@ package main import ( "context" - "fmt" + "log" "time" - "github.com/google/uuid" - "github.com/stackloklabs/gollm/pkg/backend" - "github.com/stackloklabs/gollm/pkg/config" "github.com/stackloklabs/gollm/pkg/db" - "github.com/stackloklabs/gollm/pkg/logger" +) + +var ( + ollamaHost = "http://localhost:11434" + ollamaEmbModel = "mxbai-embed-large" + ollamaGenModel = "llama3" + databaseURL = "postgres://user:password@localhost:5432/dbname?sslmode=disable" ) func main() { // Initialize Config - cfg := config.InitializeViperConfig("config", "yaml", ".") - - logger.InitLogger() // Select backends based on config var embeddingBackend backend.Backend var generationBackend backend.Backend // Choose the backend for embeddings based on the config - switch cfg.Get("backend.embeddings") { - case "ollama": - embeddingBackend = backend.NewOllamaBackend(cfg.Get("ollama.host"), cfg.Get("ollama.emb_model")) - case "openai": - embeddingBackend = backend.NewOpenAIBackend(cfg.Get("openai.api_key"), cfg.Get("openai.emb_model")) - default: - logger.Fatal("Invalid embeddings backend specified") - } - logger.Info(fmt.Sprintf("Embeddings backend: %s", cfg.Get("backend.embeddings"))) + embeddingBackend = 
backend.NewOllamaBackend(ollamaHost, ollamaEmbModel) - // Choose the backend for generation based on the config - switch cfg.Get("backend.generation") { - case "ollama": - generationBackend = backend.NewOllamaBackend(cfg.Get("ollama.host"), cfg.Get("ollama.gen_model")) - case "openai": - generationBackend = backend.NewOpenAIBackend(cfg.Get("openai.api_key"), cfg.Get("openai.gen_model")) - default: - logger.Fatal("Invalid generation backend specified") - } + log.Printf("Embedding backend LLM: %s", ollamaEmbModel) - logger.Info(fmt.Sprintf("Generation backend: %s", cfg.Get("backend.generation"))) + // Choose the backend for generation based on the config + generationBackend = backend.NewOllamaBackend(ollamaHost, ollamaGenModel) - // Initialize database connection for pgvector - dbConnString := cfg.Get("database.url") + log.Printf("Generation backend: %s", ollamaGenModel) // Initialize the vector database - vectorDB, err := db.NewPGVector(dbConnString) + vectorDB, err := db.NewPGVector(databaseURL) if err != nil { - logger.Fatalf("Failed to initialize vector database: %v", err) + log.Fatalf("Error initializing vector database: %v", err) } - logger.Info("Vector database initialized") + log.Println("Vector database initialized") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -69,79 +53,44 @@ func main() { // Embed the query using OpenAI embedding, err := embeddingBackend.Embed(ctx, ragContent) if err != nil { - logger.Fatalf("Error generating embedding: %v", err) + log.Fatalf("Error generating embedding: %v", err) } - - // Check 1536 is the expected Dimensions value (1536 is the OpenAI default) - // expectedDimensions := 1536 - // if len(embedding) != expectedDimensions { - // logger.Fatalf("Error: embedding dimensions mismatch. 
Expected %d, got %d", expectedDimensions, len(embedding)) - // } + log.Println("Embedding generated") // Insert the document into the vector store - err = insertDocument(vectorDB, ctx, ragContent, embedding) + err = db.InsertDocument(ctx, vectorDB, ragContent, embedding) if err != nil { - logger.Fatalf("Failed to insert document into vectorDB: %v", err) + log.Fatalf("Error inserting document: %v", err) } + log.Println("Vector Document generated") // Embed the query using the specified embedding backend queryEmbedding, err := embeddingBackend.Embed(ctx, query) if err != nil { - logger.Fatalf("Error generating query embedding: %v", err) + log.Fatalf("Error generating query embedding: %v", err) } + log.Println("Vector embeddings generated") // Retrieve relevant documents for the query embedding - retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding, cfg.Get("backend.embeddings")) + retrievedDocs, err := vectorDB.QueryRelevantDocuments(ctx, queryEmbedding, "ollama") if err != nil { - logger.Fatalf("Error retrieving documents: %v", err) + log.Fatalf("Error retrieving relevant documents: %v", err) } // Log the retrieved documents to see if they include the inserted content for _, doc := range retrievedDocs { - logger.Infof("RAG Retrieved Document ID: %s, Content: %v", doc.ID, doc.Metadata["content"]) + log.Printf("Retrieved Document: %v", doc) } // Augment the query with retrieved context - augmentedQuery := combineQueryWithContext(query, retrievedDocs) - logger.Infof("Augmented query Constructed using Prompt: %s", query) - - // logger.Infof("Augmented Query: %s", augmentedQuery) + augmentedQuery := db.CombineQueryWithContext(query, retrievedDocs) + log.Printf("LLM Prompt: %s", query) // Generate response with the specified generation backend response, err := generationBackend.Generate(ctx, augmentedQuery) if err != nil { - logger.Fatalf("Failed to generate response: %v", err) + log.Fatalf("Failed to generate response: %v", err) } - 
logger.Infof("Output from LLM model %s:", response) -} - -// combineQueryWithContext combines the query and retrieved documents' content to provide context for generation. -func combineQueryWithContext(query string, docs []db.Document) string { - var context string - for _, doc := range docs { - // Cast doc.Metadata["content"] to a string - if content, ok := doc.Metadata["content"].(string); ok { - context += content + "\n" - } - } - return fmt.Sprintf("Context: %s\nQuery: %s", context, query) -} - -// Example code to insert a document into the vector store -func insertDocument(vectorDB *db.PGVector, ctx context.Context, content string, embedding []float32) error { - // Generate a unique document ID (for simplicity, using a static value for testing) - docID := fmt.Sprintf("doc-%s", uuid.New().String()) - - // Create metadata - metadata := map[string]interface{}{ - "content": content, - } - - // Save the document and its embedding into the vector store - err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) - if err != nil { - return fmt.Errorf("error saving embedding: %v", err) - } - return nil + log.Printf("Retrieval-Augmented Generation influenced output from LLM model: %s", response) } diff --git a/go.mod b/go.mod index e762c97..15948f8 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,10 @@ require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v4 v4.18.3 github.com/pgvector/pgvector-go v0.2.2 - github.com/rs/zerolog v1.15.0 - github.com/spf13/viper v1.19.0 ) require ( - github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/hashicorp/hcl v1.0.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect github.com/jackc/pgio v1.0.0 // indirect @@ -21,22 +18,9 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle v1.3.0 
// indirect - github.com/magiconair/properties v1.8.7 // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/sagikazarmark/locafero v0.4.0 // indirect - github.com/sagikazarmark/slog-shim v0.1.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect - github.com/subosito/gotenv v1.6.0 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.9.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/stretchr/testify v1.9.0 // indirect golang.org/x/crypto v0.25.0 // indirect - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect - gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/pkg/backend/ollama_backend.go b/pkg/backend/ollama_backend.go index 9b26bc5..36d8b4b 100644 --- a/pkg/backend/ollama_backend.go +++ b/pkg/backend/ollama_backend.go @@ -22,8 +22,6 @@ import ( "io" "net/http" "time" - - "github.com/stackloklabs/gollm/pkg/logger" ) const ( @@ -82,47 +80,41 @@ func (o *OllamaBackend) Generate(ctx context.Context, prompt string) (string, er reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal request body: %v", err) return "", fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") - logger.Infof("Sending augmented prompt to Ollama LLM model: %s", o.Model) - resp, err := o.Client.Do(req) if err != nil 
{ - logger.Errorf("HTTP request failed: %v", err) return "", fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate response from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return "", fmt.Errorf("failed to generate response from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return "", fmt.Errorf( + "failed to generate response from Ollama: "+ + "status code %d, response: %s", + resp.StatusCode, string(bodyBytes), + ) } var result Response if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return "", fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received response from Ollama for LLM model %s", result.Model) - return result.Response, nil } // Embed generates embeddings for the given input text using the Ollama API. 
func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, error) { - logger.Infof("Ollama Embedding model %s prompt: %s", o.Model, input) - logger.Infof("Sending request to Ollama API at %s with Embedding model: %s", o.BaseURL, o.Model) + url := o.BaseURL + embedEndpoint reqBody := map[string]interface{}{ "model": o.Model, @@ -131,37 +123,34 @@ func (o *OllamaBackend) Embed(ctx context.Context, input string) ([]float32, err reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal request body: %v", err) return nil, fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := o.Client.Do(req) if err != nil { - logger.Errorf("HTTP request failed: %v", err) return nil, fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate embeddings from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return nil, fmt.Errorf("failed to generate embeddings from Ollama: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf( + "failed to generate embeddings from Ollama: "+ + "status code %d, response: %s", + resp.StatusCode, string(bodyBytes), + ) } var result OllamaEmbeddingResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return nil, fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received vector embeddings from Ollama model %s", o.Model) - return result.Embedding, nil } diff --git a/pkg/backend/ollama_backend_test.go b/pkg/backend/ollama_backend_test.go index 
f9eea60..e24cbe6 100644 --- a/pkg/backend/ollama_backend_test.go +++ b/pkg/backend/ollama_backend_test.go @@ -22,7 +22,11 @@ import ( "time" ) +const contentTypeJSON = "application/json" +const testEmbeddingText = "Test embedding text." + func TestOllamaGenerate(t *testing.T) { + t.Parallel() // Mock response from Ollama API mockResponse := Response{ Model: "test-model", @@ -39,7 +43,8 @@ func TestOllamaGenerate(t *testing.T) { } // Check Content-Type header - if r.Header.Get("Content-Type") != "application/json" { + + if r.Header.Get("Content-Type") != contentTypeJSON { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } @@ -61,8 +66,10 @@ func TestOllamaGenerate(t *testing.T) { } // Write the mock response - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + w.Header().Set("Content-Type", contentTypeJSON) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() @@ -88,6 +95,7 @@ func TestOllamaGenerate(t *testing.T) { } func TestOllamaEmbed(t *testing.T) { + t.Parallel() // Mock response from Ollama API mockResponse := OllamaEmbeddingResponse{ Embedding: []float32{0.1, 0.2, 0.3}, @@ -101,7 +109,7 @@ func TestOllamaEmbed(t *testing.T) { } // Check Content-Type header - if r.Header.Get("Content-Type") != "application/json" { + if r.Header.Get("Content-Type") != contentTypeJSON { t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) } @@ -115,13 +123,15 @@ func TestOllamaEmbed(t *testing.T) { if reqBody["model"] != "test-model" { t.Errorf("Expected model 'test-model', got '%v'", reqBody["model"]) } - if reqBody["prompt"] != "Test embedding text." 
{ + if reqBody["prompt"] != testEmbeddingText { t.Errorf("Expected prompt 'Test embedding text.', got '%v'", reqBody["prompt"]) } // Write the mock response - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + w.Header().Set("Content-Type", contentTypeJSON) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() @@ -133,7 +143,7 @@ func TestOllamaEmbed(t *testing.T) { } ctx := context.Background() - input := "Test embedding text." + input := testEmbeddingText embedding, err := backend.Embed(ctx, input) if err != nil { diff --git a/pkg/backend/openai_backend.go b/pkg/backend/openai_backend.go index 34bd0ae..f726a58 100644 --- a/pkg/backend/openai_backend.go +++ b/pkg/backend/openai_backend.go @@ -21,8 +21,6 @@ import ( "fmt" "io" "net/http" - - "github.com/stackloklabs/gollm/pkg/logger" ) // OpenAIBackend represents a backend for interacting with the OpenAI API. 
@@ -111,42 +109,35 @@ func (o *OpenAIBackend) Generate(ctx context.Context, prompt string) (string, er reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal request body: %v", err) + return "", fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+o.APIKey) - logger.Infof("Sending augmented prompt to OpenAI API at %s with model %s", url, o.Model) - resp, err := o.HTTPClient.Do(req) if err != nil { - logger.Errorf("HTTP request failed: %v", err) return "", fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate response from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return "", fmt.Errorf("failed to generate response from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return "", fmt.Errorf("failed to generate response from OpenAI: "+ + "status code %d, response: %s", resp.StatusCode, string(bodyBytes)) } var result OpenAIResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return "", fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received response from OpenAI for model %s", result.Model) - return result.Choices[0].Message.Content, nil } @@ -160,8 +151,6 @@ func (o *OpenAIBackend) Generate(ctx context.Context, prompt string) (string, er // - A slice of float32 values representing the embedding vector. // - An error if the API request fails or if there's an issue processing the response. 
func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, error) { - logger.Infof("OpenAI Embedding model %s prompt: %s", o.Model, text) - logger.Infof("Sending request to OpenAI API at %s with Embedding model: %s", o.BaseURL, o.Model) url := o.BaseURL + "/v1/embeddings" reqBody := map[string]interface{}{ "model": o.Model, @@ -170,13 +159,11 @@ func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, erro reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - logger.Errorf("Failed to marshal request body: %v", err) return nil, fmt.Errorf("failed to marshal request body: %w", err) } req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBodyBytes)) if err != nil { - logger.Errorf("Failed to create request: %v", err) return nil, fmt.Errorf("failed to create request: %w", err) } @@ -185,24 +172,20 @@ func (o *OpenAIBackend) Embed(ctx context.Context, text string) ([]float32, erro resp, err := o.HTTPClient.Do(req) if err != nil { - logger.Errorf("HTTP request failed: %v", err) return nil, fmt.Errorf("HTTP request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(resp.Body) - logger.Errorf("Failed to generate embedding from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) - return nil, fmt.Errorf("failed to generate embedding from OpenAI: status code %d, response: %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf("failed to generate embedding from OpenAI: "+ + "status code %d, response: %s", resp.StatusCode, string(bodyBytes)) } var result OpenAIEmbeddingResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.Errorf("Failed to decode response: %v", err) return nil, fmt.Errorf("failed to decode response: %w", err) } - logger.Infof("Received vector embeddings from OpenAI model %s", o.Model) - return result.Data[0].Embedding, nil } diff --git a/pkg/backend/openai_backend_test.go 
b/pkg/backend/openai_backend_test.go index 5593923..93a5965 100644 --- a/pkg/backend/openai_backend_test.go +++ b/pkg/backend/openai_backend_test.go @@ -23,6 +23,7 @@ import ( ) func TestGenerate(t *testing.T) { + t.Parallel() // Mock response from OpenAI API mockResponse := OpenAIResponse{ ID: "test-id", @@ -76,7 +77,9 @@ func TestGenerate(t *testing.T) { // Write the mock response w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() @@ -104,6 +107,7 @@ func TestGenerate(t *testing.T) { } func TestGenerateEmbedding(t *testing.T) { + t.Parallel() // Mock response from OpenAI API mockResponse := OpenAIEmbeddingResponse{ Object: "list", @@ -144,7 +148,9 @@ func TestGenerateEmbedding(t *testing.T) { // Write the mock response w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(mockResponse) + if err := json.NewEncoder(w).Encode(mockResponse); err != nil { + t.Errorf("Failed to encode mock response: %v", err) + } })) defer mockServer.Close() diff --git a/pkg/config/config.go b/pkg/config/config.go deleted file mode 100644 index 19ba8ac..0000000 --- a/pkg/config/config.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Stacklok, Inc -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -package config - -import ( - "github.com/spf13/viper" - "log" -) - -type Config interface { - Get(key string) string - GetInt(key string) int - GetBool(key string) bool -} - -// ViperConfig implements the Config interface using Viper. -type ViperConfig struct { - viper *viper.Viper -} - -// NewViperConfig initializes a ViperConfig with a given Viper instance. -func NewViperConfig(v *viper.Viper) *ViperConfig { - return &ViperConfig{viper: v} -} - -// Get returns a string value for the given key. -func (vc *ViperConfig) Get(key string) string { - return vc.viper.GetString(key) -} - -// GetInt returns an integer value for the given key. -func (vc *ViperConfig) GetInt(key string) int { - return vc.viper.GetInt(key) -} - -// GetBool returns a boolean value for the given key. -func (vc *ViperConfig) GetBool(key string) bool { - return vc.viper.GetBool(key) -} - -// InitializeViperConfig initializes and returns a Config implementation using Viper. -// It reads the configuration from the specified config file and paths. 
-func InitializeViperConfig(configName, configType, configPath string) Config { - v := viper.New() - v.SetConfigName(configName) - v.SetConfigType(configType) - v.AddConfigPath(configPath) - - // Read in the config file - if err := v.ReadInConfig(); err != nil { - log.Fatalf("Error reading config file: %v", err) - } - - // Wrap Viper with ViperConfig and return as Config - return NewViperConfig(v) -} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go deleted file mode 100644 index 98eb547..0000000 --- a/pkg/config/config_test.go +++ /dev/null @@ -1,96 +0,0 @@ -// config_test.go -package config - -import ( - "github.com/spf13/viper" - "os" - "testing" -) - -func TestViperConfig_Get(t *testing.T) { - // Create a new Viper instance and set a string value - v := viper.New() - v.Set("stringKey", "stringValue") - - // Initialize ViperConfig with the Viper instance - vc := NewViperConfig(v) - - // Test the Get method - value := vc.Get("stringKey") - if value != "stringValue" { - t.Errorf("Expected 'stringValue', got '%s'", value) - } -} - -func TestViperConfig_GetInt(t *testing.T) { - // Create a new Viper instance and set an integer value - v := viper.New() - v.Set("intKey", 42) - - // Initialize ViperConfig with the Viper instance - vc := NewViperConfig(v) - - // Test the GetInt method - value := vc.GetInt("intKey") - if value != 42 { - t.Errorf("Expected 42, got %d", value) - } -} - -func TestViperConfig_GetBool(t *testing.T) { - // Create a new Viper instance and set a boolean value - v := viper.New() - v.Set("boolKey", true) - - // Initialize ViperConfig with the Viper instance - vc := NewViperConfig(v) - - // Test the GetBool method - value := vc.GetBool("boolKey") - if value != true { - t.Errorf("Expected true, got %v", value) - } -} - -func TestInitializeViperConfig(t *testing.T) { - // Since InitializeViperConfig reads from a file, we'll create a temporary config file for testing - configName := "testconfig" - configType := "yaml" - configPath := 
"." - - // Create a temporary config file with some test data - testConfigContent := ` -stringKey: stringValue -intKey: 42 -boolKey: true -` - // Write the test config content to a temporary file - configFileName := configName + "." + configType - err := writeTempConfigFile(configFileName, testConfigContent) - if err != nil { - t.Fatalf("Failed to write temp config file: %v", err) - } - defer removeTempConfigFile(configFileName) - - // Initialize the config - cfg := InitializeViperConfig(configName, configType, configPath) - - // Test the values - if cfg.Get("stringKey") != "stringValue" { - t.Errorf("Expected 'stringValue', got '%s'", cfg.Get("stringKey")) - } - if cfg.GetInt("intKey") != 42 { - t.Errorf("Expected 42, got %d", cfg.GetInt("intKey")) - } - if cfg.GetBool("boolKey") != true { - t.Errorf("Expected true, got %v", cfg.GetBool("boolKey")) - } -} - -func writeTempConfigFile(filename, content string) error { - return os.WriteFile(filename, []byte(content), 0644) -} - -func removeTempConfigFile(filename string) { - os.Remove(filename) -} diff --git a/pkg/db/pgvector.go b/pkg/db/pgvector.go index beb4135..445da21 100644 --- a/pkg/db/pgvector.go +++ b/pkg/db/pgvector.go @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package db provides database-related functionality for the application. +// It includes implementations for vector storage and retrieval using PostgreSQL +// with the pgvector extension, enabling efficient similarity search operations +// on high-dimensional vector data. package db import ( @@ -19,10 +23,9 @@ import ( "fmt" "strings" + "github.com/google/uuid" "github.com/jackc/pgx/v4/pgxpool" "github.com/pgvector/pgvector-go" - - "github.com/stackloklabs/gollm/pkg/logger" ) // PGVector represents a connection to a PostgreSQL database with pgvector extension. 
@@ -86,8 +89,6 @@ func (pg *PGVector) SaveEmbedding(ctx context.Context, docID string, embedding [ // Log the error for debugging purposes return fmt.Errorf("failed to insert document: %w", err) } - - logger.Infof("Document inserted successfully: %s", docID) return nil } @@ -165,3 +166,33 @@ func ConvertEmbeddingToPGVector(embedding []float32) string { } return fmt.Sprintf("{%s}", strings.Join(strValues, ",")) } + +// CombineQueryWithContext combines the query and retrieved documents' content to provide context for generation. +func CombineQueryWithContext(query string, docs []Document) string { + var contextStr string + for _, doc := range docs { + // Cast doc.Metadata["content"] to a string + if content, ok := doc.Metadata["content"].(string); ok { + contextStr += content + "\n" + } + } + return fmt.Sprintf("Context: %s\nQuery: %s", contextStr, query) +} + +// InsertDocument inserts a document into the vector store +func InsertDocument(ctx context.Context, vectorDB *PGVector, content string, embedding []float32) error { + // Generate a unique document ID using a random UUID + docID := fmt.Sprintf("doc-%s", uuid.New().String()) + + // Create metadata + metadata := map[string]interface{}{ + "content": content, + } + + // Save the document and its embedding into the vector store + err := vectorDB.SaveEmbedding(ctx, docID, embedding, metadata) + if err != nil { + return fmt.Errorf("error saving embedding: %v", err) + } + return nil +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go deleted file mode 100644 index 6736fad..0000000 --- a/pkg/logger/logger.go +++ /dev/null @@ -1,61 +0,0 @@ -package logger - -import ( - "os" - - "github.com/rs/zerolog" - "github.com/stackloklabs/gollm/pkg/config" -) - -var logger zerolog.Logger - -func InitLogger() { - cfg := config.InitializeViperConfig("config", "yaml", ".") - // Set the time format for logs - zerolog.TimeFieldFormat = zerolog.TimeFormatUnix - - // Retrieve log level from
configuration (using viper as an example) - log_level := cfg.Get("log_level") - - // Parse log level - level, err := zerolog.ParseLevel(log_level) - if err != nil { - level = zerolog.InfoLevel // Default to InfoLevel - } - - logger = zerolog.New(os.Stdout).With().Timestamp().Logger() - zerolog.SetGlobalLevel(level) -} - -// Info logs an info level message -func Info(msg string) { - logger.Info().Msg(msg) -} - -// Infof logs an info level message with formatting -func Infof(format string, v ...interface{}) { - logger.Info().Msgf(format, v...) -} - -// Debug logs a debug level message -func Debug(msg string) { - logger.Debug().Msg(msg) -} - -// Fatal logs a fatal level message and then exits the program -func Fatal(msg string) { - logger.Fatal().Msg(msg) -} - -// Fatalf logs a fatal level message with formatting and then exits the program -func Fatalf(format string, v ...interface{}) { - logger.Fatal().Msgf(format, v...) -} - -func Errorf(format string, v ...interface{}) { - logger.Error().Msgf(format, v...) -} - -func Error(msg string) { - logger.Error().Msg(msg) -}