To incorporate a retriever and Neo4j vector into a LangChain application, you can create a retrieval chain.

Retrieval chain

The Neo4jVector class has an as_retriever() method that returns a retriever. By incorporating Neo4jVector into a RetrievalQA chain, you can use data and vectors in Neo4j in a LangChain application.

Review this program, which incorporates the moviePlots vector index into a retrieval chain.
import os
from dotenv import load_dotenv
load_dotenv()

from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_neo4j import Neo4jGraph, Neo4jVector

llm = ChatOpenAI(openai_api_key=os.getenv('OPENAI_API_KEY'))

embedding_provider = OpenAIEmbeddings(
    openai_api_key=os.getenv('OPENAI_API_KEY')
)

graph = Neo4jGraph(
    url=os.getenv('NEO4J_URI'),
    username=os.getenv('NEO4J_USERNAME'),
    password=os.getenv('NEO4J_PASSWORD'),
)

movie_plot_vector = Neo4jVector.from_existing_index(
    embedding_provider,
    graph=graph,
    index_name="moviePlots",
    embedding_node_property="plotEmbedding",
    text_node_property="plot",
)

plot_retriever = RetrievalQA.from_llm(
    llm=llm,
    retriever=movie_plot_vector.as_retriever()
)

response = plot_retriever.invoke(
    {"query": "A movie where a mission to the moon goes wrong"}
)
print(response)
When the program runs, the RetrievalQA chain uses the movie_plot_vector retriever to retrieve documents from the moviePlots index and passes them to the llm language model.
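The chain returns a dictionary containing the original query and the generated result. A minimal sketch of that shape (the answer text here is hypothetical, not real output from the chain):

```python
# Sketch of the dictionary a RetrievalQA chain returns.
# The "result" value is hypothetical, not real chain output.
response = {
    "query": "A movie where a mission to the moon goes wrong",
    "result": "A film about a lunar mission that goes wrong.",
}

# Access just the generated answer rather than printing the whole dict
print(response["result"])
```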
Understanding the results
It can be difficult to understand how the model generated the response and how the retriever affected it. By setting the optional verbose and return_source_documents arguments to True when creating the RetrievalQA chain, you can see how the chain executes and the source documents used to generate the response.
plot_retriever = RetrievalQA.from_llm(
    llm=llm,
    retriever=movie_plot_vector.as_retriever(),
    verbose=True,
    return_source_documents=True
)
Agent
You can add the plot_retriever chain as a tool to the chat_agent.py program you created earlier. The agent can use the chain to find similar movie plots.

To complete this optional challenge, you will need to update the 2-llm-rag-python-langchain/chat_agent.py program to:
- Create the Neo4jVector from the moviePlots vector index.
- Create the RetrievalQA chain using the Neo4jVector as the retriever.
- Update the tools to use the RetrievalQA chain.
You may need to change the name and description of the tools so the LLM can distinguish between them.
There is no right or wrong way to complete this challenge. Here is one potential solution.
import os
from dotenv import load_dotenv
load_dotenv()

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.agents import AgentExecutor, create_react_agent
from langchain.tools import Tool
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.schema import StrOutputParser
from langchain_community.tools import YouTubeSearchTool
from langchain_neo4j import Neo4jChatMessageHistory, Neo4jGraph, Neo4jVector
from uuid import uuid4

SESSION_ID = str(uuid4())
print(f"Session ID: {SESSION_ID}")

llm = ChatOpenAI(openai_api_key=os.getenv('OPENAI_API_KEY'))

embedding_provider = OpenAIEmbeddings(
    openai_api_key=os.getenv('OPENAI_API_KEY')
)

graph = Neo4jGraph(
    url=os.getenv('NEO4J_URI'),
    username=os.getenv('NEO4J_USERNAME'),
    password=os.getenv('NEO4J_PASSWORD'),
)

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a movie expert. You find movies from a genre or plot.",
        ),
        ("human", "{input}"),
    ]
)

movie_chat = prompt | llm | StrOutputParser()

youtube = YouTubeSearchTool()

movie_plot_vector = Neo4jVector.from_existing_index(
    embedding_provider,
    graph=graph,
    index_name="moviePlots",
    embedding_node_property="plotEmbedding",
    text_node_property="plot",
)

plot_retriever = RetrievalQA.from_llm(
    llm=llm,
    retriever=movie_plot_vector.as_retriever()
)

def get_memory(session_id):
    return Neo4jChatMessageHistory(session_id=session_id, graph=graph)

def call_trailer_search(input):
    input = input.replace(",", " ")
    return youtube.run(input)

tools = [
    Tool.from_function(
        name="Movie Chat",
        description="For when you need to chat about movies. The question will be a string. Return a string.",
        func=movie_chat.invoke,
    ),
    Tool.from_function(
        name="Movie Trailer Search",
        description="Use when needing to find a movie trailer. The question will include the word trailer. Return a link to a YouTube video.",
        func=call_trailer_search,
    ),
    Tool.from_function(
        name="Movie Plot Search",
        description="For when you need to compare a plot to a movie. The question will be a string. Return a string.",
        func=plot_retriever.invoke,
    ),
]

agent_prompt = hub.pull("hwchase17/react-chat")
agent = create_react_agent(llm, tools, agent_prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools)

chat_agent = RunnableWithMessageHistory(
    agent_executor,
    get_memory,
    input_messages_key="input",
    history_messages_key="chat_history",
)

while True:
    q = input("> ")
    response = chat_agent.invoke(
        {"input": q},
        {"configurable": {"session_id": SESSION_ID}},
    )
    print(response["output"])
Lesson Summary
You learned how to create a retrieval chain and incorporate it into a LangChain application.
Next you will learn about Cypher generation.