{
"cells": [
{
"cell_type": "markdown",
"id": "245ab07a-fb2f-4cf4-ab9a-5c05a9b44daa",
"metadata": {},
"source": [
"# LangChain retrieval knowledge base Q&A based on Qwen-7B-Chat"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "e8df2cb7-a69c-4231-9596-4c871d893633",
"metadata": {},
"source": [
"This notebook introduces a question-answering application based on a local knowledge base using Qwen-7B-Chat with langchain. The goal is to establish a knowledge base Q&A solution that is friendly to many scenarios and open-source models, and that can run offline. The implementation process of this project includes loading files -> reading text -> segmenting text -> vectorizing text -> vectorizing questions -> matching the top k most similar text vectors with the question vectors -> incorporating the matched text as context along with the question into the prompt -> submitting to the LLM (Large Language Model) to generate an answer."
]
},
{
"cell_type": "markdown",
"id": "92e9c81a-45c7-4c12-91af-3c5dd52f63bb",
"metadata": {},
"source": [
"## Preparation"
]
},
{
"cell_type": "markdown",
"id": "84cfcf88-3bef-4412-a658-4eaefeb6502a",
"metadata": {},
"source": [
"Download Qwen-7B-Chat\n",
"\n",
"Firstly, we need to download the model. You can use the snapshot_download that comes with modelscope to download the model to a specified directory."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c1f9ded-8035-42c7-82c7-444ce06572bc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install modelscope"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c26225c-c958-429e-b81d-2de9820670c2",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"from modelscope.hub.snapshot_download import snapshot_download\n",
"snapshot_download(\"Qwen/Qwen-7B-Chat\",cache_dir='/tmp/models') "
]
},
{
"cell_type": "markdown",
"id": "e8f51796-49fa-467d-a825-ae9a281eb3fd",
"metadata": {},
"source": [
"Download the dependencies for langchain and Qwen."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87fe1023-644f-4610-afaf-0b7cddc30d60",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"!pip install langchain==0.0.187 dashscope==1.0.4 sentencepiece==0.1.99 cpm_kernels==1.0.11 nltk==3.8.1 sentence_transformers==2.2.2 unstructured==0.6.5 faiss-cpu==1.7.4 icetk==0.0.7"
]
},
{
"cell_type": "markdown",
"id": "853cdfa4-a2ce-4baa-919a-b9e2aecd2706",
"metadata": {},
"source": [
"Download the retrieval document."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ba800dc-311d-4a83-8115-f05b09b39ffd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/qwen_recipes/LLM_Survey_Chinese.pdf.txt"
]
},
{
"cell_type": "markdown",
"id": "07e923b3-b7ae-4983-abeb-2ce115566f15",
"metadata": {},
"source": [
"Download the text2vec model, for Chinese in our case."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a07cd8d-3cec-40f6-8d2b-eb111aaf1164",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/qwen_recipes/GanymedeNil_text2vec-large-chinese.tar.gz\n",
"!tar -zxvf GanymedeNil_text2vec-large-chinese.tar.gz -C /tmp"
]
},
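{
"cell_type": "markdown",
"id": "c0ffee10-1111-4a2b-9c3d-aabbccddeef1",
"metadata": {},
"source": [
"Optionally, sanity-check the embedding model before wiring it into the pipeline. This cell is a sketch we add for convenience, not part of the original recipe; text2vec-large-chinese is expected to produce 1024-dimensional vectors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee11-2222-4a2b-9c3d-aabbccddeef2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Optional sanity check: load the downloaded embedding model and embed a query.\n",
"from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
"\n",
"embeddings = HuggingFaceEmbeddings(\n",
"    model_name=\"/tmp/GanymedeNil_text2vec-large-chinese\",\n",
"    model_kwargs={\"device\": \"cuda\"},\n",
")\n",
"vec = embeddings.embed_query(\"大语言模型\")  # \"large language model\"\n",
"print(len(vec))  # expected embedding dimensionality: 1024"
]
},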
{
"cell_type": "markdown",
"id": "dc483af0-170e-4e61-8d25-a336d1592e34",
"metadata": {},
"source": [
"## Try out the model \n",
"\n",
"Load the Qwen-7B-Chat model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c112cf82-0447-46c4-9c32-18f243c0a686",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"from abc import ABC\n",
"from langchain.llms.base import LLM\n",
"from typing import Any, List, Mapping, Optional\n",
"from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model_path=\"/tmp/models/Qwen/Qwen-7B-Chat\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
"model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()\n",
"model.eval()\n",
"\n",
"class Qwen(LLM, ABC):\n",
" max_token: int = 10000\n",
" temperature: float = 0.01\n",
" top_p = 0.9\n",
" history_len: int = 3\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" @property\n",
" def _llm_type(self) -> str:\n",
" return \"Qwen\"\n",
"\n",
" @property\n",
" def _history_len(self) -> int:\n",
" return self.history_len\n",
"\n",
" def set_history_len(self, history_len: int = 10) -> None:\n",
" self.history_len = history_len\n",
"\n",
" def _call(\n",
" self,\n",
" prompt: str,\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" ) -> str:\n",
" response, _ = model.chat(tokenizer, prompt, history=[])\n",
" return response\n",
" \n",
" @property\n",
" def _identifying_params(self) -> Mapping[str, Any]:\n",
" \"\"\"Get the identifying parameters.\"\"\"\n",
" return {\"max_token\": self.max_token,\n",
" \"temperature\": self.temperature,\n",
" \"top_p\": self.top_p,\n",
" \"history_len\": self.history_len}\n",
" "
]
},
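{
"cell_type": "markdown",
"id": "c0ffee12-3333-4a2b-9c3d-aabbccddeef3",
"metadata": {},
"source": [
"As a quick smoke test (an optional addition; the prompt is an arbitrary example, not from the original recipe), instantiate the wrapper and ask it a question. LangChain LLM objects are directly callable with a prompt string."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee13-4444-4a2b-9c3d-aabbccddeef4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Smoke test: call the wrapped model once through the LangChain interface.\n",
"llm = Qwen()\n",
"print(llm(\"你好,请简单介绍一下你自己。\"))  # \"Hello, please briefly introduce yourself.\""
]
},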
{
"cell_type": "markdown",
"id": "382ed433-870f-424e-b074-210ea6f84b70",
"metadata": {},
"source": [
"Specify the txt file that needs retrieval for knowledge-based Q&A."
]
},
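{
"cell_type": "markdown",
"id": "c0ffee14-5555-4a2b-9c3d-aabbccddeef5",
"metadata": {},
"source": [
"The script below imports `ChineseTextSplitter` from a local `chinese_text_splitter` module that is not created anywhere in this notebook. If you do not already have that file, the next cell writes a minimal stand-in to disk. It is a sketch under our own assumptions (sentence-level splitting on Chinese and Western punctuation, greedily merged into chunks of at most `sentence_size` characters), not the original implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee15-6666-4a2b-9c3d-aabbccddeef6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%%writefile chinese_text_splitter.py\n",
"# Minimal stand-in for the chinese_text_splitter module (assumption: the\n",
"# original file is unavailable). Splits Chinese text into sentence-aligned\n",
"# chunks of at most sentence_size characters.\n",
"import re\n",
"from typing import List\n",
"\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"\n",
"\n",
"class ChineseTextSplitter(CharacterTextSplitter):\n",
"    def __init__(self, pdf: bool = False, sentence_size: int = 100, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        self.pdf = pdf\n",
"        self.sentence_size = sentence_size\n",
"\n",
"    def split_text(self, text: str) -> List[str]:\n",
"        if self.pdf:\n",
"            # PDF-extracted text often carries page-break artifacts; collapse them.\n",
"            text = re.sub(r\"\\n{3,}\", \"\\n\", text)\n",
"        # Split after Chinese and Western sentence-ending punctuation.\n",
"        sentences = [s for s in re.split(r\"(?<=[。!?;.!?;])\", text) if s.strip()]\n",
"        # Greedily merge consecutive sentences into chunks <= sentence_size chars.\n",
"        chunks, buf = [], \"\"\n",
"        for sent in sentences:\n",
"            if buf and len(buf) + len(sent) > self.sentence_size:\n",
"                chunks.append(buf)\n",
"                buf = sent\n",
"            else:\n",
"                buf += sent\n",
"        if buf:\n",
"            chunks.append(buf)\n",
"        return chunks"
]
},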
{
"cell_type": "code",
"execution_count": null,
"id": "14be706b-4a7d-4906-9369-1f03c6c99854",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import argparse\n",
"from langchain.vectorstores import FAISS\n",
"from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
"from typing import List, Tuple\n",
"import numpy as np\n",
"from langchain.document_loaders import TextLoader\n",
"from chinese_text_splitter import ChineseTextSplitter\n",
"from langchain.docstore.document import Document\n",
"from langchain.prompts.prompt import PromptTemplate\n",
"from langchain.chains import RetrievalQA\n",
"\n",
"\n",
"def load_file(filepath, sentence_size=100):\n",
" loader = TextLoader(filepath, autodetect_encoding=True)\n",
" textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)\n",
" docs = loader.load_and_split(textsplitter)\n",
" write_check_file(filepath, docs)\n",
" return docs\n",
"\n",
"\n",
"def write_check_file(filepath, docs):\n",
" folder_path = os.path.join(os.path.dirname(filepath), \"tmp_files\")\n",
" if not os.path.exists(folder_path):\n",
" os.makedirs(folder_path)\n",
" fp = os.path.join(folder_path, 'load_file.txt')\n",
" with open(fp, 'a+', encoding='utf-8') as fout:\n",
" fout.write(\"filepath=%s,len=%s\" % (filepath, len(docs)))\n",
" fout.write('\\n')\n",
" for i in docs:\n",
" fout.write(str(i))\n",
" fout.write('\\n')\n",
" fout.close()\n",
"\n",
" \n",
"def seperate_list(ls: List[int]) -> List[List[int]]:\n",
" lists = []\n",
" ls1 = [ls[0]]\n",
" for i in range(1, len(ls)):\n",
" if ls[i - 1] + 1 == ls[i]:\n",
" ls1.append(ls[i])\n",
" else:\n",
" lists.append(ls1)\n",
" ls1 = [ls[i]]\n",
" lists.append(ls1)\n",
" return lists\n",
"\n",
"\n",
"class FAISSWrapper(FAISS):\n",
" chunk_size = 250\n",
" chunk_conent = True\n",
" score_threshold = 0\n",
" \n",
" def similarity_search_with_score_by_vector(\n",
" self, embedding: List[float], k: int = 4\n",
" ) -> List[Tuple[Document, float]]:\n",
" scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)\n",
" docs = []\n",
" id_set = set()\n",
" store_len = len(self.index_to_docstore_id)\n",
" for j, i in enumerate(indices[0]):\n",
" if i == -1 or 0 < self.score_threshold < scores[0][j]:\n",
" # This happens when not enough docs are returned.\n",
" continue\n",
" _id = self.index_to_docstore_id[i]\n",
" doc = self.docstore.search(_id)\n",
" if not self.chunk_conent:\n",
" if not isinstance(doc, Document):\n",
" raise ValueError(f\"Could not find document for id {_id}, got {doc}\")\n",
" doc.metadata[\"score\"] = int(scores[0][j])\n",
" docs.append(doc)\n",
" continue\n",
" id_set.add(i)\n",
" docs_len = len(doc.page_content)\n",
" for k in range(1, max(i, store_len - i)):\n",
" break_flag = False\n",
" for l in [i + k, i - k]:\n",
" if 0 <= l < len(self.index_to_docstore_id):\n",
" _id0 = self.index_to_docstore_id[l]\n",
" doc0 = self.docstore.search(_id0)\n",
" if docs_len + len(doc0.page_content) > self.chunk_size:\n",
" break_flag = True\n",
" break\n",
" elif doc0.metadata[\"source\"] == doc.metadata[\"source\"]:\n",
" docs_len += len(doc0.page_content)\n",
" id_set.add(l)\n",
" if break_flag:\n",
" break\n",
" if not self.chunk_conent:\n",
" return docs\n",
" if len(id_set) == 0 and self.score_threshold > 0:\n",
" return []\n",
" id_list = sorted(list(id_set))\n",
" id_lists = seperate_list(id_list)\n",
" for id_seq in id_lists:\n",
" for id in id_seq:\n",
" if id == id_seq[0]:\n",
" _id = self.index_to_docstore_id[id]\n",
" doc = self.docstore.search(_id)\n",
" else:\n",
" _id0 = self.index_to_docstore_id[id]\n",
" doc0 = self.docstore.search(_id0)\n",
" doc.page_content += \" \" + doc0.page_content\n",
" if not isinstance(doc, Document):\n",
" raise ValueError(f\"Could not find document for id {_id}, got {doc}\")\n",
" doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])\n",
" doc.metadata[\"score\"] = int(doc_score)\n",
" docs.append((doc, doc_score))\n",
" return docs\n",
"\n",
"\n",
"if __name__ == '__main__':\n",
" # load docs\n",
" filepath = 'LLM_Survey_Chinese.pdf.txt'\n",
" # LLM name\n",
" LLM_TYPE = 'qwen'\n",
" # Embedding model name\n",
" EMBEDDING_MODEL = 'text2vec'\n",
" # 基于上下文的prompt模版请务必保留\"{question}\"和\"{context_str}\"\n",
" PROMPT_TEMPLATE = \"\"\"已知信息:\n",
" {context_str} \n",
" 根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}\"\"\"\n",
" # Embedding running device\n",
" EMBEDDING_DEVICE = \"cuda\"\n",
" # return top-k text chunk from vector store\n",
" VECTOR_SEARCH_TOP_K = 3\n",
" # 文本分句长度\n",
" SENTENCE_SIZE = 50\n",
" CHAIN_TYPE = 'stuff'\n",
" llm_model_dict = {\n",
" \"qwen\": QWen,\n",
" }\n",
" embedding_model_dict = {\n",
" \"text2vec\": \"/tmp/GanymedeNil_text2vec-large-chinese\",\n",
" }\n",
" print(\"loading model start\")\n",
" llm = llm_model_dict[LLM_TYPE]()\n",
" embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],model_kwargs={'device': EMBEDDING_DEVICE})\n",
" print(\"loading model done\")\n",
"\n",
" print(\"loading documents start\")\n",
" docs = load_file(filepath, sentence_size=SENTENCE_SIZE)\n",
" print(\"loading documents done\")\n",
"\n",
" print(\"embedding start\")\n",
" docsearch = FAISSWrapper.from_documents(docs, embeddings)\n",
" print(\"embedding done\")\n",
"\n",
" print(\"loading qa start\")\n",
" prompt = PromptTemplate(\n",
" template=PROMPT_TEMPLATE, input_variables=[\"context_str\", \"question\"]\n",
" )\n",
"\n",
" chain_type_kwargs = {\"prompt\": prompt, \"document_variable_name\": \"context_str\"}\n",
" qa = RetrievalQA.from_chain_type(\n",
" llm=llm,\n",
" chain_type=CHAIN_TYPE, \n",
" retriever=docsearch.as_retriever(search_kwargs={\"k\": VECTOR_SEARCH_TOP_K}), \n",
" chain_type_kwargs=chain_type_kwargs)\n",
" print(\"loading qa done\")\n",
"\n",
" query = \"大模型指令微调有好的策略?\" \n",
" print(qa.run(query))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}