一、需求 基于 LangChain 和 Streamlit 的 Web 应用,用于使用 LLM 和嵌入从 SQLite 数据库中搜索相关的 offer。用户可以输入与品牌、类别或零售商相关的搜索查询,也支持通过 SQL 语句进行搜索,应用程序将从数据库中检索并显示相关的 offer。该应用使用 OpenAI API 进行自然语言处理和嵌入生成。
SQLite 官网:https://www.sqlite.org/pragma.html#toc
SQLite 使用手册:https://www.runoob.com/sqlite/sqlite-select.html
二、方法
目标 :该方法的目标是基于产品类别、品牌或零售商查询从 offer_retailer
表中提取相关的offer。鉴于所需数据分散在 data
目录中的多个表中,采用了语言模型(LLM)来促进智能数据库查询。
数据库准备 :最初,使用存储在 data
目录中的 .csv
文件构建了一个本地 SQLite 数据库。这是通过 sqlite3
和 pandas
库实现的。
LLM 集成 :通过 langchain_experimental.sql.SQLDatabaseChain
实现了语言模型(LLM)与本地 SQLite 数据库的有效交互。
提示工程 :该方法的一个重要方面是制定合适的提示,以指导 LLM 最佳地检索和格式化数据库条目。通过多次迭代和实验来微调这个提示。
相似度评分 :为了确定检索结果与查询的相关性,进行了余弦相似度比较。使用 langchain_openai.OpenAIEmbeddings
生成嵌入进行比较,从而对结果进行排序。
Streamlit 集成 :最后一步是解析 LLM 的输出,并围绕它构建一个用户友好的 Streamlit 应用,允许用户进行交互式搜索。
三、环境 在开始之前,请确保满足以下要求:
安装所需的包:
1 pip install -r requirements.txt
确保您的 SQLite 数据库已设置好,并包含必要的表(brand_category
,categories
,offer_retailer
)。
注意:streamlit版本需要<1.30,一般为1.29.0,否则启动会报以下错误。
四、代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 import osimport sqlite3import pandas as pdimport streamlit as stfrom llm import RetrievalLLMDATA_PATH = 'data' TABLES = ('brand_category' , 'categories' , 'offer_retailer' ) DB_NAME = 'offer_db.sqlite' PROMPT_TEMPLATE = """ 你会接收到一个查询,你的任务是从`offer_retailer`表中的`OFFER`字段检索相关offer。 查询可能是混合大小写的,所以也要搜索大写版本的查询。 重要的是,你可能需要使用数据库中其他表的信息,即:`brand_category`, `categories`, `offer_retailer`,来检索正确的offer。 不要虚构offer。如果在`offer_retailer`表中找不到offer,返回字符串:`NONE`。 如果你能从`offer_retailer`表中检索到offer,用分隔符`#`分隔每个offer。例如,输出应该是这样的:`offer1#offer2#offer3`。 如果SQLResult为空,返回`None`。不要生成任何offer。 这是查询:`{}` """ st.title("搜索offer 🔍" ) conn = sqlite3.connect('offer_db.sqlite' ) def is_sql_query (query ): sql_keywords = ['SELECT' , 'INSERT' , 'UPDATE' , 'DELETE' , 'CREATE' , 'DROP' , 'ALTER' ,'TRUNCATE' , 'MERGE' , 'CALL' , 'EXPLAIN' , 'DESCRIBE' , 'SHOW' ] query_upper = query.strip().upper() for keyword in sql_keywords: if query_upper.startswith(keyword): return True sql_pattern = re.compile (r'^\s*(SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|TRUNCATE|MERGE|CALL|EXPLAIN|DESCRIBE|SHOW)\s+' , re.IGNORECASE ) if sql_pattern.match (query): return True return False with st.form("search_form" ): query = st.text_input("通过类别、品牌或发布商搜索offer。" ) submitted = st.form_submit_button("搜索" ) retrieval_llm = RetrievalLLM( data_path=DATA_PATH, tables=TABLES, db_name=DB_NAME, openai_api_key=os.getenv('OPENAI_API_KEY' ), ) if submitted: if is_sql_query(query): st.write(pd.read_sql_query(query, conn)) else : retrieved_offers = retrieval_llm.retrieve_offers( PROMPT_TEMPLATE.format (query) ) if not retrieved_offers: st.text("未找到相关offer。" ) else : st.table(retrieval_llm.parse_output(retrieved_offers, query))
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 import sqlite3import numpy as npimport pandas as pdfrom langchain_openai import OpenAIEmbeddingsfrom langchain_openai import OpenAIfrom langchain_community.utilities import SQLDatabasefrom langchain_experimental.sql import SQLDatabaseChainclass RetrievalLLM : """一个类,用于使用大型语言模型(LLM)检索和重新排序offer。 参数: data_path (str): 包含数据CSV文件的目录路径。 tables (list[str]): 数据CSV文件的名称列表。 db_name (str): 用于存储数据的SQLite数据库名称。 openai_api_key (str): OpenAI API密钥。 属性: data_path (str): 包含数据CSV文件的目录路径。 tables (list[str]): 数据CSV文件的名称列表。 db_name (str): 用于存储数据的SQLite数据库名称。 openai_api_key (str): OpenAI API密钥。 db (SQLDatabase): SQLite数据库连接。 llm (OpenAI): OpenAI LLM客户端。 embeddings (OpenAIEmbeddings): OpenAI嵌入客户端。 db_chain (SQLDatabaseChain): 与LLM集成的SQL数据库链。 """ def init (self, data_path, tables, db_name, openai_api_key ): self .data_path = data_path self .tables = tables self .db_name = db_name self .openai_api_key = openai_api_key dfs = {} for table in self .tables: dfs[table] = pd.read_csv(f"{self.data_path} /{table} .csv" ) with sqlite3.connect(self .db_name) as local_db: for table, df in dfs.items(): df.to_sql(table, local_db, if_exists="replace" ) self .db = SQLDatabase.from_uri(f"sqlite:///{self.db_name} " ) self .llm = OpenAI( temperature=0 , verbose=True , openai_api_key=self .openai_api_key ) self .embeddings = OpenAIEmbeddings(openai_api_key=self .openai_api_key) self .db_chain = SQLDatabaseChain.from_llm(self .llm, self .db) self .allow_reuse = True def retrieve_offers (self, prompt ): """使用LLM从数据库中检索offer。 参数: prompt (str): 用于检索offer的提示。 返回: list[str]: 检索到的offer列表。 """ retrieved_offers = self .db_chain.run(prompt) return None if retrieved_offers == "None" else retrieved_offers def get_embeddings (self, documents ): """使用LLM获取文档的嵌入。 参数: documents (list[str]): 文档列表。 返回: np.ndarray: 包含文档嵌入的NumPy数组。 """ if len (documents) == 1 : return np.asarray(self .embeddings.embed_query(documents[0 ])) else : embeddings_list = [] for document in documents: embeddings_list.append(self .embeddings.embed_query(document)) return np.asarray(embeddings_list) def parse_output (self, retrieved_offers, query ): """解析retrieve_offers()方法的输出并返回一个数据帧。 参数: retrieved_offers (list[str]): 检索到的offer列表。 query (str): 用于检索offer的查询。 返回: pd.DataFrame: 包含匹配相似度和offer的数据帧。 """ top_offers = retrieved_offers.split("#" ) query_embedding = self .get_embeddings([query]) offer_embeddings = self .get_embeddings(top_offers) sim_scores = np.dot(offer_embeddings, query_embedding.T).flatten() sim_scores = [p * 100 for p in sim_scores] df = ( pd.DataFrame({"匹配相似度 %" : sim_scores, "offer" : top_offers}) .sort_values(by=["匹配相似度 %" ], ascending=False ) .reset_index(drop=True ) ) df.index += 1 return df
五、运行 本地运行应用
应用运行后,打开浏览器并导航到 http://localhost:8501
访问offer搜索界面。
在文本输入框中输入您的搜索查询(品牌、类别或零售商)。
点击“搜索”按钮启动搜索。
匹配查询的相关 offer 将以表格形式显示。
六、问答效果 问题1:select * from categories
问题2:select CATEGORY_ID from categories
问题3:RED GOLD