Routing Requests

In this lesson, you will use LangGraph to create a conditional chain that will choose whether to execute a vector search or the Cypher QA Chain.
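
Under the hood, the routing pattern is a single `addConditionalEdges()` call: the function you pass in inspects the state and returns the name of the node to run next. Here is a minimal sketch of that pattern in isolation, with placeholder node names and a toy decision function standing in for the real router built later in this lesson:

```typescript
import { END, START, StateGraph, StateGraphArgs } from "@langchain/langgraph";

type State = { question: string };

const channels: StateGraphArgs<State>["channels"] = { question: null };

const graph = new StateGraph({ channels })
  .addNode("vector", async (state: State) => state) // placeholder for the vector search node
  .addNode("cypher", async (state: State) => state) // placeholder for the Cypher QA node
  // The string returned here is the name of the node that runs next
  .addConditionalEdges(START, async (state: State) =>
    state.question.includes("talk") ? "vector" : "cypher"
  )
  .addEdge("vector", END)
  .addEdge("cypher", END);

const app = graph.compile();
```

The rest of this lesson builds the real version of each node, starting with the retrieval chain that performs the vector search.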

src/modules/agent/nodes/talks.ts
```typescript
import { Neo4jVectorStore } from "@langchain/community/vectorstores/neo4j_vector";
import { StringOutputParser } from "@langchain/core/output_parsers";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import {
  ChatPromptTemplate,
  HumanMessagePromptTemplate,
  SystemMessagePromptTemplate,
} from "@langchain/core/prompts";
import {
  RunnablePassthrough,
  RunnablePick,
  RunnableSequence,
} from "@langchain/core/runnables";

export async function initRetrievalChain() {
  // Specify embedding model
  const embeddings = new OpenAIEmbeddings({
    openAIApiKey: process.env.OPENAI_API_KEY,
  });

  // Create vector store
  const store = await Neo4jVectorStore.fromExistingGraph(embeddings, {
    url: process.env.NEO4J_URI as string,
    username: process.env.NEO4J_USERNAME as string,
    password: process.env.NEO4J_PASSWORD as string,
    nodeLabel: "Talk",
    textNodeProperties: ["title", "description"],
    indexName: "talk_embeddings_openai",
    embeddingNodeProperty: "embedding",
    retrievalQuery: `
      RETURN node.description AS text, score,
        node {
          .time, .title,
          url: 'https://athens.cityjsconf.org/' + node.url,
          speaker: [
            (node)-[:GIVEN_BY]->(s) |
            s { .name, .company, .x_handle, .bio }
          ][0],
          room: [ (node)-[:IN_ROOM]->(r) | r.name ][0],
          tags: [ (node)-[:HAS_TAG]->(t) | t.name ]
        } AS metadata
    `,
  });

  // Retrieve Documents from Vector Index
  const retriever = store.asRetriever();

  // 1. Create a prompt template
  const prompt = ChatPromptTemplate.fromMessages([
    SystemMessagePromptTemplate.fromTemplate(
      `You are a helpful assistant helping users with queries
      about the CityJS Athens conference.
      Answer the user's question to the best of your ability.
      If you do not know the answer, just say you don't know.
      `
    ),
    SystemMessagePromptTemplate.fromTemplate(
      `
      Here are some talks to help you answer the question.
      Don't use your pre-trained knowledge to answer the question.
      Always include a full link to the talk.
      If the answer isn't included in the documents, say you don't know.

      Documents:
      {documents}
    `
    ),
    HumanMessagePromptTemplate.fromTemplate(`Question: {message}`),
  ]);

  // 2. Choose an LLM
  const llm = new ChatOpenAI({
    openAIApiKey: process.env.OPENAI_API_KEY,
    temperature: 0.9,
  });

  // 3. Parse the response as a string
  const parser = new StringOutputParser();

  // 4. Combine into a runnable sequence (LCEL)
  const chain = RunnableSequence.from<
    { message: string; documents?: string },
    string
  >([
    RunnablePassthrough.assign({
      documents: new RunnablePick("message").pipe(
        retriever.pipe((docs) => JSON.stringify(docs))
      ),
    }),
    prompt,
    llm,
    parser,
  ]);

  return chain;
}
```
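
To sanity-check the retrieval chain in isolation, you can invoke it directly from a script, assuming the environment variables above are set. The question below is just an example:

```typescript
import { initRetrievalChain } from "./src/modules/agent/nodes/talks";

async function main() {
  const chain = await initRetrievalChain();

  // Embeds the message, retrieves similar Talk nodes from the vector
  // index, and answers using the retrieved documents as context
  const answer = await chain.invoke({
    message: "Are there any talks about TypeScript?",
  });

  console.log(answer);
}

main();
```
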
src/modules/agent/nodes/database.ts
```typescript
import { Neo4jGraph } from "@langchain/community/graphs/neo4j_graph";
import { ChatOpenAI } from "@langchain/openai";
import { GraphCypherQAChain } from "@langchain/community/chains/graph_qa/cypher";

export async function initCypherQAChain() {
  const llm = new ChatOpenAI({ model: "gpt-4-turbo" });
  const graph = await Neo4jGraph.initialize({
    url: process.env.NEO4J_URI as string,
    username: process.env.NEO4J_USERNAME as string,
    password: process.env.NEO4J_PASSWORD as string,
    database: process.env.NEO4J_DATABASE as string | undefined,
    enhancedSchema: true,
  });

  const chain = GraphCypherQAChain.fromLLM({
    graph,
    llm,
    returnDirect: true,
  });

  return chain;
}
```
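
Because `returnDirect` is set to `true`, the chain returns the raw results of the generated Cypher query rather than a natural-language summary. A quick, illustrative way to test it on its own:

```typescript
import { initCypherQAChain } from "./src/modules/agent/nodes/database";

async function main() {
  const chain = await initCypherQAChain();

  // The LLM writes a Cypher query from the question, runs it against
  // the graph, and the raw query results are returned directly
  const result = await chain.invoke({
    query: "How many talks are there at the conference?",
  });

  console.log(result);
}

main();
```
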
src/modules/agent/nodes/router.ts
```typescript
import { RunnableConfig } from "@langchain/core/runnables";
import { StructuredOutputParser } from "@langchain/core/output_parsers";
import { z } from "zod";
import { PromptTemplate } from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import {
  AgentState,
  NODE_DATABASE_QUERY,
  NODE_JOKE,
  NODE_TALK_RETRIEVER,
} from "../constants";

export const router = async (data: AgentState, config?: RunnableConfig) => {
  const prompt = PromptTemplate.fromTemplate(`
    You are an AI agent deciding which tool to use.

    Follow the rules below to come to your conclusion:

    * If the question relates to the description of a talk and can be answered with
    the contents of the talk title or description, respond with "${NODE_TALK_RETRIEVER}"
    * If the question relates to Talks or Speakers and can be answered by a database
    query, respond with "${NODE_DATABASE_QUERY}".
    * For all other queries, respond with "${NODE_JOKE}".

    Question: {question}

    {format_instructions}
  `);
  const llm = new ChatOpenAI({
    openAIApiKey: process.env.OPENAI_API_KEY,
  });
  const parser = StructuredOutputParser.fromZodSchema(
    z.enum([NODE_TALK_RETRIEVER, NODE_DATABASE_QUERY, NODE_JOKE])
  );

  const chain = prompt.pipe(llm).pipe(parser);

  return chain.invoke({
    question: data.rephrased,
    format_instructions: parser.getFormatInstructions(),
  });
};
```
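
The router returns one of the three node names, which is exactly what `addConditionalEdges()` expects: the returned string tells LangGraph which node to run next. You can check the routing decision on its own before wiring it into the graph; the helper and questions below are only illustrative:

```typescript
import { router } from "./src/modules/agent/nodes/router";
import { AgentState } from "./src/modules/agent/constants";

// Build a minimal AgentState around a rephrased question
const state = (rephrased: string): AgentState => ({
  input: rephrased,
  rephrased,
  messages: [],
  output: "",
  log: [],
});

async function main() {
  // Answerable from a talk title or description -> "talk"
  console.log(await router(state("Which talks mention GraphQL?")));

  // An aggregation over the graph -> "database"
  console.log(await router(state("How many speakers are giving talks?")));

  // Anything else -> "joke"
  console.log(await router(state("Tell me something funny")));
}

main();
```
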
src/modules/agent/constants.ts
```typescript
import { BaseMessage } from "@langchain/core/messages";

export type AgentState = {
  input: string;
  rephrased: string;
  messages: BaseMessage[];
  output: string;
  log: string[];
};

export const NODE_REPHRASE = "rephrase";
export const NODE_ROUTER = "router";
export const NODE_TALK_RETRIEVER = "talk";
export const NODE_DATABASE_QUERY = "database";
export const NODE_JOKE = "joke";
```

src/modules/agent/nodes/rephrase.ts
```typescript
import { getHistory } from "../history";
import { StringOutputParser } from "@langchain/core/output_parsers";
import {
  ChatPromptTemplate,
  HumanMessagePromptTemplate,
  MessagesPlaceholder,
  SystemMessagePromptTemplate,
} from "@langchain/core/prompts";
import { RunnableConfig, RunnableSequence } from "@langchain/core/runnables";
import { AgentState } from "../constants";
import { ChatOpenAI } from "@langchain/openai";

export const rephraseQuestion = async (
  data: AgentState,
  config?: RunnableConfig
) => {
  const llm = new ChatOpenAI({ temperature: 0 });
  const history = await getHistory(config?.configurable?.sessionId, 5);

  const rephrase = ChatPromptTemplate.fromMessages([
    SystemMessagePromptTemplate.fromTemplate(`
        Use the following conversation history to rephrase the input
        into a standalone question.
      `),
    new MessagesPlaceholder("history"),
    HumanMessagePromptTemplate.fromTemplate(`Input: {input}`),
  ]);

  const rephraseChain = RunnableSequence.from([
    rephrase,
    llm,
    new StringOutputParser(),
  ]);

  const rephrased = await rephraseChain.invoke({
    history,
    input: data.input,
  });

  // Log the original input alongside the rephrased question for debugging
  console.log({
    input: data.input,
    rephrased,
  });

  // Write to the "messages" and "rephrased" channels of the agent state
  return {
    messages: history,
    rephrased,
  };
};
```
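
The `getHistory()` helper isn't shown in this lesson. As a rough sketch of what it is assumed to do, it loads the most recent exchanges for a session and returns them as `BaseMessage` instances. The node labels, properties, and query below are illustrative assumptions, not the course's actual implementation:

```typescript
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
import neo4j from "neo4j-driver";

// Hypothetical implementation: the (:Session)/(:Response) schema is an assumption
export async function getHistory(
  sessionId: string | undefined,
  limit: number
): Promise<BaseMessage[]> {
  if (!sessionId) return [];

  const driver = neo4j.driver(
    process.env.NEO4J_URI as string,
    neo4j.auth.basic(
      process.env.NEO4J_USERNAME as string,
      process.env.NEO4J_PASSWORD as string
    )
  );

  const { records } = await driver.executeQuery(
    `MATCH (:Session {id: $sessionId})-[:HAS_RESPONSE]->(r:Response)
     RETURN r.input AS input, r.output AS output
     ORDER BY r.createdAt DESC LIMIT $limit`,
    { sessionId, limit: neo4j.int(limit) }
  );

  await driver.close();

  // Oldest exchange first, alternating human and AI turns
  return records
    .reverse()
    .flatMap((row) => [
      new HumanMessage(row.get("input")),
      new AIMessage(row.get("output")),
    ]);
}
```
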
src/modules/agent/index.ts
```typescript
import { END, START, StateGraph, StateGraphArgs } from "@langchain/langgraph";
import {
  AgentState,
  NODE_DATABASE_QUERY,
  NODE_JOKE,
  NODE_REPHRASE,
  NODE_TALK_RETRIEVER,
} from "./constants";
import { rephraseQuestion } from "./nodes/rephrase";
import { router } from "./nodes/router";
import { initRetrievalChain } from "./nodes/talks";
import { initCypherQAChain } from "./nodes/database";
import { tellJoke } from "./nodes/joke";

const agentState: StateGraphArgs<AgentState>["channels"] = {
  input: null,
  rephrased: null,
  messages: null,
  output: null,
  // Accumulate log entries from each node rather than overwriting them
  log: {
    value: (x: string[], y: string[]) => x.concat(y),
    default: () => [],
  },
};

/**
 * Build the agent graph:
 * START -> rephrase -> router -> (talk | database | joke) -> END
 */

export async function buildLangGraphAgent() {
  const talkChain = await initRetrievalChain();
  const databaseChain = await initCypherQAChain();

  const graph = new StateGraph({
    channels: agentState,
  })

    // 1. Get conversation history and rephrase the question
    .addNode(NODE_REPHRASE, rephraseQuestion)
    .addEdge(START, NODE_REPHRASE)

    // 2. Route the request
    .addConditionalEdges(NODE_REPHRASE, router)

    // 3. Call Vector tool
    .addNode(NODE_TALK_RETRIEVER, async (data: AgentState) => {
      const output = await talkChain.invoke({ message: data.rephrased });
      return { output };
    })
    .addEdge(NODE_TALK_RETRIEVER, END)

    // 4. Call CypherQAChain
    .addNode(NODE_DATABASE_QUERY, async (data: AgentState) => {
      // Cast the output: the chain's declared return type doesn't match the string it returns
      const output = (await databaseChain.invoke({
        query: data.rephrased,
      })) as unknown as string;

      return { output };
    })
    .addEdge(NODE_DATABASE_QUERY, END)

    // 5. Tell a joke
    .addNode(NODE_JOKE, tellJoke)
    .addEdge(NODE_JOKE, END);

  const app = await graph.compile();

  return app;
}

export async function call(input: string, sessionId?: string) {
  const agent = await buildLangGraphAgent();

  const res = await agent.invoke({ input }, { configurable: { sessionId } });

  return res.output;
}
```
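
Finally, exercise the whole graph through the `call()` function. With a shared `sessionId`, the rephrase step can resolve follow-up questions against the conversation history. The questions and session id below are examples:

```typescript
import { call } from "./src/modules/agent";

async function main() {
  const sessionId = "demo-session";

  // Routed to the talk retriever (vector search)
  console.log(await call("What talks cover accessibility?", sessionId));

  // Routed to the database node (Cypher QA chain)
  console.log(await call("How many talks are there in total?", sessionId));
}

main();
```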

Summary

In this lesson, you used LangGraph to build an agent that rephrases the user's question and routes it to a vector search, the Cypher QA chain, or a fallback joke node.

Well done!