Create a retriever

To use a Neo4j vector index as a retriever in a LangChain application, you can create a retrieval chain.

Retrieval chain

The Neo4jVector class has an as_retriever() method that returns a retriever.

By incorporating Neo4jVector into a RetrievalQA chain, you can use data and vectors in Neo4j in a LangChain application.
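To build some intuition for what a vector retriever does, here is a minimal sketch in plain Python. It is not how Neo4jVector is implemented: the plots, the three-dimensional vectors, and the `retrieve` helper are all hypothetical, and cosine similarity is assumed as the ranking measure. The idea is only that a query vector is compared with stored vectors and the closest texts are returned.

```python
import math


def cosine_similarity(a, b):
    """Return the cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(query_vector, indexed_plots, k=2):
    """Return the k (score, text) pairs most similar to the query vector."""
    scored = [
        (cosine_similarity(query_vector, vector), text)
        for text, vector in indexed_plots
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:k]


# Hypothetical plots with made-up embeddings (real embeddings are much longer).
indexed_plots = [
    ("Astronauts fight to get home after an explosion on the way to the moon.", [0.9, 0.1, 0.0]),
    ("A chef opens a restaurant in Paris.", [0.0, 0.2, 0.9]),
    ("A crew is stranded on Mars.", [0.8, 0.3, 0.1]),
]

top = retrieve([1.0, 0.0, 0.0], indexed_plots, k=2)
for score, text in top:
    print(f"{score:.3f}  {text}")
```

In the real application, the embedding provider produces the vectors and the moviePlots index performs the search, so none of this code is needed.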

Review this program incorporating the moviePlots vector index into a retrieval chain.

python
import os
from dotenv import load_dotenv
load_dotenv()

from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_neo4j import Neo4jGraph, Neo4jVector

llm = ChatOpenAI(openai_api_key=os.getenv('OPENAI_API_KEY'))

embedding_provider = OpenAIEmbeddings(
    openai_api_key=os.getenv('OPENAI_API_KEY')
)

graph = Neo4jGraph(
    url=os.getenv('NEO4J_URI'),
    username=os.getenv('NEO4J_USERNAME'),
    password=os.getenv('NEO4J_PASSWORD'),
)

movie_plot_vector = Neo4jVector.from_existing_index(
    embedding_provider,
    graph=graph,
    index_name="moviePlots",
    embedding_node_property="plotEmbedding",
    text_node_property="plot",
)

plot_retriever = RetrievalQA.from_llm(
    llm=llm,
    retriever=movie_plot_vector.as_retriever()
)

response = plot_retriever.invoke(
    {"query": "A movie where a mission to the moon goes wrong"}
)

print(response)

When the program runs, the RetrievalQA chain uses the retriever created from movie_plot_vector to fetch documents from the moviePlots index and passes them to the llm language model.
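The step of passing retrieved documents to the language model can be pictured as building one prompt from the question and the document text. The sketch below is illustrative only: `build_prompt` is a hypothetical helper and the prompt wording is a placeholder, not the template that RetrievalQA actually uses.

```python
def build_prompt(question, documents):
    """Combine retrieved document texts and a question into one prompt string."""
    context = "\n\n".join(documents)
    return (
        "Use the following context to answer the question.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n"
        "Answer:"
    )


# Placeholder texts standing in for plots returned by the retriever.
retrieved_plots = ["first retrieved plot", "second retrieved plot"]

prompt = build_prompt(
    "A movie where a mission to the moon goes wrong",
    retrieved_plots,
)
print(prompt)
```

The chain does this assembly for you, which is why the program only needs to supply the query.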

Understanding the results

It can be difficult to understand how the model generated the response and how the retriever affected it.

By setting the optional verbose and return_source_documents arguments to True when creating the RetrievalQA chain, you can see the source documents and the retriever’s score for each document.

python
plot_retriever = RetrievalQA.from_llm(
    llm=llm,
    retriever=movie_plot_vector.as_retriever(),
    verbose=True,
    return_source_documents=True
)
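
With return_source_documents=True, the dictionary returned by the chain includes a source_documents list next to the query and result keys. The following sketch shows one way to print them together. The `summarize_response` helper is hypothetical, and the response here is a stand-in built with SimpleNamespace so the example runs without a database or an API key. Real entries are LangChain Document objects, which also expose page_content and metadata.

```python
from types import SimpleNamespace


def summarize_response(response):
    """Format the answer and any source documents from a chain response."""
    lines = [f"Answer: {response['result']}"]
    for number, doc in enumerate(response.get("source_documents", []), start=1):
        lines.append(f"Source {number}: {doc.page_content}")
        lines.append(f"  metadata: {doc.metadata}")
    return "\n".join(lines)


# Stand-in for the chain output; all values are placeholders.
fake_response = {
    "query": "A movie where a mission to the moon goes wrong",
    "result": "placeholder answer",
    "source_documents": [
        SimpleNamespace(
            page_content="placeholder plot text",
            metadata={"title": "placeholder title"},
        )
    ],
}

print(summarize_response(fake_response))
```

In the lesson program, you could pass the value returned by plot_retriever.invoke() to the helper instead of fake_response.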

Agent

You can add the plot_retriever chain as a tool to the chat_agent.py program you created earlier. The agent can use the chain to find similar movie plots.

To complete this optional challenge, you will need to update the 2-llm-rag-python-langchain/chat_agent.py program to:

  1. Create the Neo4jVector from the moviePlots vector index.

  2. Create the RetrievalQA chain using the Neo4jVector as the retriever.

  3. Update the tools to use the RetrievalQA chain.

You may need to change the name and description of the tools so the LLM can distinguish between them.

There is no right or wrong way to complete this challenge. Here is one potential solution.

python
import os
from dotenv import load_dotenv
load_dotenv()

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.agents import AgentExecutor, create_react_agent
from langchain.tools import Tool
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.schema import StrOutputParser
from langchain_community.tools import YouTubeSearchTool
from langchain_neo4j import Neo4jChatMessageHistory, Neo4jGraph, Neo4jVector
from uuid import uuid4

SESSION_ID = str(uuid4())
print(f"Session ID: {SESSION_ID}")

llm = ChatOpenAI(openai_api_key=os.getenv('OPENAI_API_KEY'))

embedding_provider = OpenAIEmbeddings(
    openai_api_key=os.getenv('OPENAI_API_KEY')
)

graph = Neo4jGraph(
    url=os.getenv('NEO4J_URI'),
    username=os.getenv('NEO4J_USERNAME'),
    password=os.getenv('NEO4J_PASSWORD'),
)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a movie expert. You find movies from a genre or plot.",
        ),
        ("human", "{input}"),
    ]
)

movie_chat = prompt | llm | StrOutputParser()

youtube = YouTubeSearchTool()

movie_plot_vector = Neo4jVector.from_existing_index(
    embedding_provider,
    graph=graph,
    index_name="moviePlots",
    embedding_node_property="plotEmbedding",
    text_node_property="plot",
)

plot_retriever = RetrievalQA.from_llm(
    llm=llm,
    retriever=movie_plot_vector.as_retriever()
)

def get_memory(session_id):
    return Neo4jChatMessageHistory(session_id=session_id, graph=graph)

def call_trailer_search(input):
    input = input.replace(",", " ")
    return youtube.run(input)

tools = [
    Tool.from_function(
        name="Movie Chat",
        description="For when you need to chat about movies. The question will be a string. Return a string.",
        func=movie_chat.invoke,
    ),
    Tool.from_function(
        name="Movie Trailer Search",
        description="Use when needing to find a movie trailer. The question will include the word trailer. Return a link to a YouTube video.",
        func=call_trailer_search,
    ),
    Tool.from_function(
        name="Movie Plot Search",
        description="For when you need to compare a plot to a movie. The question will be a string. Return a string.",
        func=plot_retriever.invoke,
    ),
]

agent_prompt = hub.pull("hwchase17/react-chat")
agent = create_react_agent(llm, tools, agent_prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)

chat_agent = RunnableWithMessageHistory(
    agent_executor,
    get_memory,
    input_messages_key="input",
    history_messages_key="chat_history",
)

while True:
    q = input("> ")

    response = chat_agent.invoke(
        {
            "input": q
        },
        {"configurable": {"session_id": SESSION_ID}},
    )
    
    print(response["output"])
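
The solution above passes plot_retriever.invoke directly as the tool function, so the tool returns the chain's whole response dictionary. As an optional variation, a small wrapper can return only the answer text, which matches the "Return a string" wording in the tool description. This is a sketch: `make_plot_search` and `FakeChain` are hypothetical names, and `FakeChain` only stands in for the RetrievalQA chain so the example runs on its own.

```python
def make_plot_search(chain):
    """Build a tool function that returns only the chain's answer text."""

    def call_plot_search(question):
        # RetrievalQA expects the question under the "query" key
        # and returns the answer under the "result" key.
        response = chain.invoke({"query": question})
        return response["result"]

    return call_plot_search


class FakeChain:
    """Stand-in for the RetrievalQA chain (assumption, for illustration only)."""

    def invoke(self, inputs):
        return {
            "query": inputs["query"],
            "result": f"stub answer for: {inputs['query']}",
        }


plot_search = make_plot_search(FakeChain())
print(plot_search("A movie where a mission to the moon goes wrong"))
```

In chat_agent.py, the "Movie Plot Search" tool could then use func=make_plot_search(plot_retriever).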

Continue

When you are ready, you can move on to the next task.

Lesson Summary

You learned how to create a retrieval chain and incorporate it into a LangChain application.

Next, you will learn about Cypher generation.