import {
  getClaimAnalyticsAgentNode,
  getClaimAnalyticsAgentToolNode,
  getRagAgentNode,
  getRagAgentToolNode,
  getSupervisorAgentNode,
} from "duck/graph/nodes";
import { graphState, GraphStateType } from "duck/graph/state";
import { RequestParams } from "duck/graph/types";
import { Runnable } from "@langchain/core/runnables";
import { END, MemorySaver, START, StateGraph } from "@langchain/langgraph/web";

import { shouldContinue } from "./utils";

/**
 * @summary Build and compile the supervisor multi-agent graph.
 *
 * A supervisor node routes each request to one of two agents: the RAG agent or
 * the claim analytics agent. The route is chosen from `state.next`. Each agent
 * may then loop through its own tool node until `shouldContinue` routes it to
 * END.
 *
 * @param params The parameters for the agents from the UI. They are forwarded
 *   to the RAG and claim analytics agent node factories.
 * @param withMemory True to compile with an in-memory `MemorySaver`
 *   checkpointer; false (the default) to compile without any checkpointer.
 * @returns The compiled graph, ready to process the user utterance.
 */
const getGraph = async (
  params: RequestParams,
  withMemory = false
): Promise<Runnable> => {
  // The MemorySaver checkpointer is meant for experimentation and is not
  // intended for production usage.
  const checkpointer = withMemory ? new MemorySaver() : undefined;

  // The agent node factories do not depend on one another, so build them
  // concurrently instead of awaiting each in turn.
  const [supervisorNode, ragNode, claimAnalyticsNode] = await Promise.all([
    getSupervisorAgentNode(),
    getRagAgentNode(params),
    getClaimAnalyticsAgentNode(params),
  ]);

  /**
   * State graph layout:
   *
   *   START -> supervisor --(state.next)--> "rag"            -> rag
   *                                     \--> "claimAnalytics" -> claimAnalytics
   *
   *   rag            --(shouldContinue)--> "tools" -> ragTools            -> rag
   *                                    \--> END
   *
   *   claimAnalytics --(shouldContinue)--> "tools" -> claimAnalyticsTools -> claimAnalytics
   *                                    \--> END
   *
   * Each tool node feeds back into its agent, which re-evaluates
   * `shouldContinue`. The agent can therefore call tools repeatedly before it
   * ends.
   */
  const stateGraph = new StateGraph(graphState)
    .addNode("supervisor", supervisorNode)
    .addNode("rag", ragNode)
    .addNode("ragTools", getRagAgentToolNode())
    .addNode("claimAnalytics", claimAnalyticsNode)
    .addNode("claimAnalyticsTools", getClaimAnalyticsAgentToolNode())
    .addEdge(START, "supervisor")
    // Route on `state.next` (presumably set by the supervisor node). Only the
    // keys "rag" and "claimAnalytics" are mapped here.
    .addConditionalEdges("supervisor", (state: GraphStateType) => state.next, {
      rag: "rag",
      claimAnalytics: "claimAnalytics",
    })
    .addConditionalEdges("rag", shouldContinue, {
      tools: "ragTools",
      [END]: END,
    })
    // After the tools run, hand the results back to the agent so it can decide again.
    .addEdge("ragTools", "rag")
    .addConditionalEdges("claimAnalytics", shouldContinue, {
      tools: "claimAnalyticsTools",
      [END]: END,
    })
    .addEdge("claimAnalyticsTools", "claimAnalytics");

  // Compile the graph, attaching the checkpointer only when one was created.
  return stateGraph.compile({
    checkpointer,
  });
};

export default getGraph;
