| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Command ragserver is an HTTP server that implements RAG (Retrieval |
| // Augmented Generation) using the Gemini model and Weaviate, which |
| // are accessed using LangChainGo. See the accompanying README file for |
| // additional details. |
| package main |
| |
import (
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strings"

	"github.com/tmc/langchaingo/embeddings"
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/googleai"
	"github.com/tmc/langchaingo/schema"
	"github.com/tmc/langchaingo/vectorstores/weaviate"
)
| |
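// Names of the Gemini models used for text generation and for computing
// embeddings.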
| const generativeModelName = "gemini-1.5-flash" |
| const embeddingModelName = "text-embedding-004" |
| |
// This is a standard Go HTTP server. Server state is kept in the ragServer
// struct. The main function connects to the required services (Weaviate and
// Google AI), initializes the server state, and registers the HTTP handlers.
| func main() { |
| ctx := context.Background() |
	apiKey := os.Getenv("GEMINI_API_KEY")
	if apiKey == "" {
		log.Fatal("GEMINI_API_KEY environment variable is not set")
	}
| geminiClient, err := googleai.New(ctx, |
| googleai.WithAPIKey(apiKey), |
| googleai.WithDefaultEmbeddingModel(embeddingModelName)) |
| if err != nil { |
| log.Fatal(err) |
| } |
| |
| emb, err := embeddings.NewEmbedder(geminiClient) |
| if err != nil { |
| log.Fatal(err) |
| } |
| |
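	// Connect to the Weaviate vector store. It is expected to be running
	// locally; the port can be overridden with the WVPORT env var.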
	wvStore, err := weaviate.New(
		weaviate.WithEmbedder(emb),
		weaviate.WithScheme("http"),
		weaviate.WithHost("localhost:"+cmp.Or(os.Getenv("WVPORT"), "9035")),
		weaviate.WithIndexName("Document"),
	)
	if err != nil {
		log.Fatal(err)
	}
| |
| server := &ragServer{ |
| ctx: ctx, |
| wvStore: wvStore, |
| geminiClient: geminiClient, |
| } |
| |
| mux := http.NewServeMux() |
| mux.HandleFunc("POST /add/", server.addDocumentsHandler) |
| mux.HandleFunc("POST /query/", server.queryHandler) |
| |
| port := cmp.Or(os.Getenv("SERVERPORT"), "9020") |
| address := "localhost:" + port |
| log.Println("listening on", address) |
| log.Fatal(http.ListenAndServe(address, mux)) |
| } |
| |
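// ragServer holds the state shared by the HTTP handlers: a base context,
// the Weaviate vector store, and the Gemini client.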
| type ragServer struct { |
| ctx context.Context |
| wvStore weaviate.Store |
| geminiClient *googleai.GoogleAI |
| } |
| |
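// addDocumentsHandler handles "POST /add/". It expects a JSON body like
//
//	{"documents": [{"text": "..."}]}
//
// and stores the documents, along with their computed embeddings, in the
// Weaviate vector store.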
| func (rs *ragServer) addDocumentsHandler(w http.ResponseWriter, req *http.Request) { |
| // Parse HTTP request from JSON. |
| type document struct { |
| Text string |
| } |
| type addRequest struct { |
| Documents []document |
| } |
| ar := &addRequest{} |
| |
| err := readRequestJSON(req, ar) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusBadRequest) |
| return |
| } |
| |
	// Store the documents and their embeddings in Weaviate.
| var wvDocs []schema.Document |
| for _, doc := range ar.Documents { |
| wvDocs = append(wvDocs, schema.Document{PageContent: doc.Text}) |
| } |
| _, err = rs.wvStore.AddDocuments(rs.ctx, wvDocs) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } |
| } |
| |
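// queryHandler handles "POST /query/". It expects a JSON body like
//
//	{"content": "..."}
//
// retrieves the stored documents most similar to the query, and asks the
// generative model to answer the query using those documents as context.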
| func (rs *ragServer) queryHandler(w http.ResponseWriter, req *http.Request) { |
| // Parse HTTP request from JSON. |
| type queryRequest struct { |
| Content string |
| } |
| qr := &queryRequest{} |
| err := readRequestJSON(req, qr) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusBadRequest) |
| return |
| } |
| |
| // Find the most similar documents. |
| docs, err := rs.wvStore.SimilaritySearch(rs.ctx, qr.Content, 3) |
| if err != nil { |
| http.Error(w, fmt.Errorf("similarity search: %w", err).Error(), http.StatusInternalServerError) |
| return |
| } |
| var docsContents []string |
| for _, doc := range docs { |
| docsContents = append(docsContents, doc.PageContent) |
| } |
| |
	// Create a RAG query for the LLM with the most relevant documents as
	// context.
| ragQuery := fmt.Sprintf(ragTemplateStr, qr.Content, strings.Join(docsContents, "\n")) |
| respText, err := llms.GenerateFromSinglePrompt(rs.ctx, rs.geminiClient, ragQuery, llms.WithModel(generativeModelName)) |
| if err != nil { |
| log.Printf("calling generative model: %v", err.Error()) |
| http.Error(w, "generative model error", http.StatusInternalServerError) |
| return |
| } |
| |
| renderJSON(w, respText) |
| } |
| |
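// ragTemplateStr is the prompt sent to the generative model. Its two %s
// verbs are filled in with the user's question and the retrieved context
// documents, in that order.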
| const ragTemplateStr = ` |
| I will ask you a question and will provide some additional context information. |
| Assume this context information is factual and correct, as part of internal |
| documentation. |
| If the question relates to the context, answer it using the context. |
| If the question does not relate to the context, answer it as normal. |
| |
| For example, let's say the context has nothing in it about tropical flowers; |
| then if I ask you about tropical flowers, just answer what you know about them |
| without referring to the context. |
| |
Conversely, if the context does mention mineralogy and I ask you about that,
provide information from the context along with general knowledge.
| |
| Question: |
| %s |
| |
| Context: |
| %s |
| ` |