import { formatDocs, NodeType } from "duck/graph/nodes/utils";
import { GraphStateType } from "duck/graph/state";
import { getLLM } from "duck/graph/utils";
import { JsonOutputToolsParser } from "@langchain/core/output_parsers/openai_tools";

import getPrompt from "./getPrompt";

/**
 * Graph nodes the router may hand control to next.
 *
 * The array is embedded by reference as the `enum` of the `route` tool's
 * `next` argument, so it is the set of values the tool schema allows.
 */
export const NEXT_NODE_OPTIONS: string[] = ["claimAnalytics", "rag"];

/** JSON Schema for the `route` tool's arguments: one required `next` field. */
const routeParameters = {
  title: "routeSchema",
  type: "object",
  properties: {
    next: {
      title: "Next",
      // A single-branch anyOf is equivalent to a bare `enum`; it is kept in
      // this form so the schema sent to the model stays exactly the same.
      anyOf: [{ enum: NEXT_NODE_OPTIONS }],
    },
  },
  required: ["next"],
  // Disallow any argument other than `next`.
  additionalProperties: false,
} as const;

/**
 * OpenAI-style function tool that the router LLM is made to call in order to
 * name the next node.
 */
const toolDef = {
  type: "function",
  function: {
    name: "route",
    description: "Select the next node based on the last message",
    parameters: routeParameters,
  },
} as const;

/**
 * Builds the routing node for the graph.
 *
 * The node renders the prompt from the current graph state, forces the LLM to
 * call the `route` tool, and resolves to that tool call's parsed arguments
 * (shaped by `toolDef`, e.g. `{ next: "rag" }`).
 *
 * @returns A graph node. When invoked, it rejects with an `Error` if the model
 *   response contains no tool call.
 */
const getNode = async (): Promise<NodeType> => {
  const prompt = await getPrompt();
  const llm = getLLM();
  // Force the `route` tool so the answer is always structured; `strict: true`
  // requests schema-conforming arguments (enforcement is provider-dependent).
  const validateLLM = llm.bindTools([toolDef], {
    tool_choice: { type: "function", function: { name: "route" } },
    strict: true,
  });

  const validateChain = prompt
    .pipe(validateLLM)
    .pipe(new JsonOutputToolsParser())
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- tool args are untyped parsed JSON
    .pipe((toolCalls: Array<{ args: any }>) => {
      // Only the first tool call is used. An empty list would otherwise
      // surface as an opaque "cannot read properties of undefined" TypeError.
      const [first] = toolCalls;
      if (first === undefined) {
        throw new Error("Router LLM response did not contain a `route` tool call");
      }
      return first.args;
    });

  return async (state: GraphStateType) => {
    const response = await validateChain.invoke({
      messages: state.messages,
      // The prompt receives page state and retrieved documents as text.
      current_state: JSON.stringify(state.pageState),
      context: formatDocs(state.documents),
    });
    return response;
  };
};

export default getNode;
