spRAG是什么

  spRAG是一个针对非结构化数据的检索引擎。它特别擅长处理对密集文本的复杂查询,比如财务报告、法律文件和学术论文。有两种关键方法用于提高性能,超越了普通的RAG系统:

自动上下文(AutoContext):自动上下文的实现相当直接,首先生成文档的1-2句话摘要,将文件名添加进去,然后在嵌入之前将这些内容添加到每个文本块的前面。
相关段落提取(Relevant Segment Extraction, RSE):相关段落提取(RSE)是一个后处理步骤,它将相关文本块的集群智能地组合成长文本段落,我们称之为段落。这些段落比任何单独的文本块都能为LLM提供更好的上下文。对于简单的事实问题,答案通常包含在单个文本块中;但对于更复杂的问题,答案通常跨越更长的文本段落。RSE的目标是智能地识别提供最相关信息的文本段落,而不受固定长度文本块的限制。

spRAG使用

  spRAG的使用非常简单,安装sprag后,只需下面简单的几行代码,即可基于pdf文档内容进行问题检索。

from sprag.create_kb import create_kb_from_file

file_path = "spRAG/tests/data/levels_of_agi.pdf"
kb_id = "levels_of_agi"
kb = create_kb_from_file(kb_id, file_path)
search_queries = ["What are the levels of AGI?", "What is the highest level of AGI?"]
results = kb.query(search_queries)
for segment in results:
    print(segment)

    默认情况下,spRAG使用OpenAI进行嵌入,使用Claude 3 Haiku进行自动上下文处理,并使用Cohere进行重新排序。所以,如果要使用spRAG需要同时申请这三种大模型的key。如果只用OpenAI需要编写更多的代码,具体如下所示:下面的代码中通过OpenAIChatAPI来设置自动文本提取的llm使用openAI,重新排序这里,设置的NoReranker,经过下面的修改,即便只有OPENAI的key,也能正常使用spRAG了。

from sprag.llm import OpenAIChatAPI
from sprag.reranker import NoReranker

llm = OpenAIChatAPI(model='gpt-3.5-turbo')
reranker = NoReranker()

kb = KnowledgeBase(kb_id="levels_of_agi", reranker=reranker, auto_context_model=llm)

spRAG实现原理

  实际spRAG的source code如下图所示,并不复杂,首先看vector_db部分,vector_db_connectors文件夹里面是直接使用第三方vectorDB工具weavite,对文档进行向量存储,添加,删除,以及查询等操作。外面的vector_db.py也定义了增加,删除文档,检索内容等方法,这个module会被vector_db_connectors中的python文件继承。

    在vector_db.py问中,封装了两种search方法,第一种是使用余玄相识度来进行内容检索,方法二是调用knn中的exhuastive_search方法来进行检索。

  以上就是向量存储方法的内容,接着看Ebedding的source code,这里没有啥复杂度,直接调用的各个LLM提供的SDK,生成向量即可。以下图为例,这里调用OpenAI提供的SDK来对输入的内容生成向量。

class OpenAIEmbedding(Embedding):
    def __init__(self, model: str = "text-embedding-3-small", dimension: int = 768):
        """
        Only v3 models are supported.
        """
        super().__init__(dimension)
        self.model = model
        self.client = OpenAI()

    def get_embeddings(self, text, input_type=None):
        response = self.client.embeddings.create(input=text, model=self.model, dimensions=int(self.dimension))
        embeddings = [embedding_item.embedding for embedding_item in response.data]
        return embeddings[0] if isinstance(text, str) else embeddings
    
    def to_dict(self):
        base_dict = super().to_dict()
        base_dict.update({
            'model': self.model
        })
        return base_dict

  接着来看看比较关键的自动上下文提取是如何实现的。下面的代码就是auto_context的source code,这段代码使用LLM生成简明扼要的文档描述,描述包含文档的标题和内容概要,并确保处理大文档时内容不会超出限制。具体代码内容,如下所示:

from sprag.llm import LLM
import tiktoken

PROMPT = """
INSTRUCTIONS
What is the following document, and what is it about? 

Your response should be a single sentence, and it shouldn't be an excessively long sentence. DO NOT respond with anything else.

You MUST include the name of the document in your response (if available), as that is a critical piece of information. Be as specific and detailed as possible in your document name. You can even include things like the author's name or the date of publication if that information is available. DO NOT just use the filename as the document name. It needs to be a descriptive and human-readable name.

Your response should take the form of "This document is: X, and is about: Y". For example, if the document is a book about the history of the United States called A People's History of the United States, your response might be "This document is: A People's History of the United States, and is about the history of the United States, covering the period from 1776 to the present day." If the document is the 2023 Form 10-K for Apple Inc., your response might be "This document is: Apple Inc. FY2023 Form 10-K, and is about: the financial performance and operations of Apple Inc. during the fiscal year 2023."

{auto_context_guidance}

{truncation_message}

DOCUMENT
filename: {document_title}

{document}
""".strip()

TRUNCATION_MESSAGE = """
Also note that the document text provided below is just the first ~4500 words of the document. Your response should still pertain to the entire document, not just the text provided below.
""".strip()

def truncate_content(content: str, max_tokens: int):
    TOKEN_ENCODER = tiktoken.encoding_for_model('gpt-3.5-turbo')
    tokens = TOKEN_ENCODER.encode(content, disallowed_special=())
    truncated_tokens = tokens[:max_tokens]
    return TOKEN_ENCODER.decode(truncated_tokens), min(len(tokens), max_tokens)

def get_document_context(auto_context_model: LLM, text: str, document_title: str, auto_context_guidance: str = ""):
    # truncate the content if it's too long
    max_content_tokens = 6000 # if this number changes, also update the truncation message above
    text, num_tokens = truncate_content(text, max_content_tokens)
    if num_tokens < max_content_tokens:
        truncation_message = ""
    else:
        truncation_message = TRUNCATION_MESSAGE
    
    # get document context
    prompt = PROMPT.format(auto_context_guidance=auto_context_guidance, document=text, document_title=document_title, truncation_message=truncation_message)
    chat_messages = [{"role": "user", "content": prompt}]
    document_context = auto_context_model.make_llm_call(chat_messages)
    return document_context

def get_chunk_header(file_name, document_context):
    chunk_header = f"Document context: the following excerpt is from {file_name}. {document_context}"
    return chunk_header

    以上是一些比较关键的source code简要说明,接下来看看如何从调用入口开始,看看加载文件后,如何完成对文档的检索的。代码的source code封装的方法create_kb_from_file,就是使用spRAG的入口方法。可看到这个方法中,根据文件的后缀进行了不同的处理,总体而言,通过这样的处理,让spRAG可以支持多种文件格式。kb这个对象实际就是KnowlegeBase这个class的实例化。

def create_kb_from_file(kb_id: str, file_path: str, title: str = None, description: str = "", language: str = 'en', auto_context: bool = True, auto_context_guidance: str = ""):
    """
    - kb_id is the name of the knowledge base
    - file_path is the absolute path to the file containing the documents

    Supported file types: .docx, .md, .txt, .pdf
    """
    if not title:
        title = kb_id
    
    # create a new KB
    kb = KnowledgeBase(kb_id, title=title, description=description, language=language, exists_ok=False)
    
    print (f'Creating KB with id {kb_id}...')

    file_name = os.path.basename(file_path)

    # add document
    if file_path.endswith(('.docx', '.md', '.txt', '.pdf')):
        # define clean file path as just the file name here since we're not using a directory
        clean_file_path = file_name
        
        if file_path.endswith('.docx'):
            text = extract_text_from_docx(file_path)
        elif file_name.endswith('.pdf'):
            text = extract_text_from_pdf(file_path)
        elif file_path.endswith('.md') or file_path.endswith('.txt'):
            with open(file_path, 'r') as f:
                text = f.read()

        kb.add_document(clean_file_path, text, auto_context=auto_context, auto_context_guidance=auto_context_guidance)
    else:
        print (f"Unsupported file type: {file_name}")
        return
    
    return kb

KnowlegeBase中包含两个重要方法

  add_document方法是将一个新文档添加到知识库中。首先验证参数的有效性,确保auto_context和chunk_header不同时设置,并检查文档ID的唯一性。如果启用了自动上下文生成,会通过LLM生成文档上下文描述并设置小块头信息。接着,它将文档文本拆分为小块,添加头信息后生成嵌入向量,然后将小块及其嵌入向量添加到数据库中,最后保存数据库以确保数据持久化。也就是前面提到的很关键的自动化上下文能力,具体的source code如下所示:

def add_document(self, doc_id: str, text: str, auto_context: bool = True, chunk_header: str = None, auto_context_guidance: str = ""):
        # verify that only one of auto_context and chunk_header is set
        try:
            assert auto_context != (chunk_header is not None)
        except:
            print ("Error in add_document: only one of auto_context and chunk_header can be set")

        # verify that the document does not already exist in the KB
        if doc_id in self.chunk_db.get_all_doc_ids():
            print (f"Document with ID {doc_id} already exists in the KB. Skipping...")
            return
        
        # AutoContext
        if auto_context:
            document_context = get_document_context(self.auto_context_model, text, document_title=doc_id, auto_context_guidance=auto_context_guidance)
            chunk_header = get_chunk_header(file_name=doc_id, document_context=document_context)
        elif chunk_header:
            pass
        else:
            chunk_header = ""

        chunks = self.split_into_chunks(text)
        print (f'Adding {len(chunks)} chunks to the database')

        # add chunk headers to the chunks before embedding them
        chunks_to_embed = []
        for i, chunk in enumerate(chunks):
            chunk_to_embed = f'[{chunk_header}]\n{chunk}'
            chunks_to_embed.append(chunk_to_embed)

        # embed the chunks
        if len(chunks) <= 50:
            # if the document is short, we can get all the embeddings at once
            chunk_embeddings = self.get_embeddings(chunks_to_embed, input_type="document")
        else:
            # if the document is long, we need to get the embeddings in chunks
            chunk_embeddings = []
            for i in range(0, len(chunks), 50):
                chunk_embeddings += self.get_embeddings(chunks_to_embed[i:i+50], input_type="document")

        assert len(chunks) == len(chunk_embeddings) == len(chunks_to_embed)
        self.chunk_db.add_document(doc_id, {i: {'chunk_text': chunk, 'chunk_header': chunk_header} for i, chunk in enumerate(chunks)})

        # create metadata list
        metadata = []
        for i, chunk in enumerate(chunks):
            metadata.append({'doc_id': doc_id, 'chunk_index': i, 'chunk_header': chunk_header, 'chunk_text': chunk})

        # add the vectors and metadata to the vector database
        self.vector_db.add_vectors(vectors=chunk_embeddings, metadata=metadata)

        self.save() # save the database to disk after adding a document

  再来看看search方法,search方法总体比较简单,首先就是调用常规的vector_db.search进行查询,获取到多个结果,再将问题和首次查询到的结果传递给rerank_search,进行二次查询,从而提高查询的准确率。

def search(self, query: str, top_k: int) -> list:
        """
        Get top k most relevant chunks for a given query. This is where we interface with the vector database.
        - returns a list of dictionaries, where each dictionary has the following keys: `metadata` (which contains 'doc_id', 'chunk_index', 'chunk_text', and 'chunk_header') and `similarity`
        """
        query_vector = self.get_embeddings(query, input_type="query") # embed the query
        search_results = self.vector_db.search(query_vector, top_k) # do a vector database search
        search_results = self.reranker.rerank_search_results(query, search_results) # rerank search results using a reranker
        return search_results

  以上就是对开源框架spRAG的学习小结。更多内容可查看官网

点赞(0) 打赏

评论列表 共有 0 条评论

暂无评论

微信公众账号

微信扫一扫加关注

发表
评论
返回
顶部