import { GPT4OMINI_MODEL_SPEC } from "duck/graph/constants";
import loadPrompt from "duck/graph/nodes/loadPrompt";
import { PromptName, promptNames } from "duck/graph/nodes/types";
import {
  createAgent,
  getAgentNodes,
  getEphemeralMessageForNode,
  NodeOutputType,
  NodeType,
} from "duck/graph/nodes/utils";
import { PageHandlerRoute } from "duck/graph/PageHandler/types";
import { graphState, GraphStateType } from "duck/graph/state";
import { DuckGraphParams } from "duck/graph/types";
import { drawGraph, getLLM, NodeNames, NodeNamesType } from "duck/graph/utils";
import { captureScreenshot } from "duck/ui/utils";
import { AIMessage, RemoveMessage } from "@langchain/core/messages";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { RunnableConfig } from "@langchain/core/runnables";
import { StructuredTool } from "@langchain/core/tools";
import {
  CompiledGraph,
  END,
  START,
  StateGraph,
} from "@langchain/langgraph/web";

import { cloneObject } from "shared/utils";

import { PageAgentSubgraphNodeNames } from "./constants";

/**
 * Create the graph node that captures a screenshot of the current page and
 * passes it, together with the user's input, to the
 * `PAGE_AGENT_ANALYZE_SCREENSHOT_AGENT` prompt-driven agent.
 *
 * The agent's LLM is built from a copy of `GPT4OMINI_MODEL_SPEC` with
 * `maxTokens` set to 100; the shared constant itself is left untouched.
 *
 * @param params - graph parameters; `uiHandlers.setEphemeralMessage` is used
 *   to show progress while the node runs
 * @returns an async node returning the agent's reply (named
 *   `NodeNames.ANALYZE_SCREENSHOT`) as a `messages` state update
 */
const getExtractFromScreenshotNode = async (
  params: DuckGraphParams
): Promise<NodeType> => {
  const prompt = await loadPrompt(
    promptNames.PAGE_AGENT_ANALYZE_SCREENSHOT_AGENT
  );

  // Clone before overriding so the shared model spec constant is not mutated.
  const modelSpec = cloneObject(GPT4OMINI_MODEL_SPEC);
  modelSpec.maxTokens = 100;
  // Build the LLM from the overridden copy; passing the original constant
  // here would silently discard the maxTokens override above.
  const llm = getLLM(modelSpec);

  const agent = createAgent(llm, [], prompt, { strict: true });

  const name = NodeNames.ANALYZE_SCREENSHOT;
  const setEphemeralMessage = params.uiHandlers.setEphemeralMessage;

  return async (
    { userInput }: GraphStateType,
    config: RunnableConfig = {}
  ): Promise<NodeOutputType> => {
    setEphemeralMessage(getEphemeralMessageForNode(name));
    console.debug(`AnalyzeScreenshotAgent: ${name}`);

    const screenshot = await captureScreenshot();

    const agentMessage = await agent.invoke(
      {
        userInput,
        screenshot,
      },
      config
    );
    // Tag the reply so downstream nodes can identify where it came from.
    agentMessage.name = name;

    return {
      messages: agentMessage,
    };
  };
};

/**
 * Graph node that strips the router's trailing AI message from state.
 *
 * When the newest message is an `AIMessage` named `"router"` that has an id,
 * the node emits a `RemoveMessage` for that id. In every other case, including
 * an empty history, it returns an empty update and leaves state unchanged.
 */
const deleteRouterAIMessageNode = ({
  messages,
}: GraphStateType): NodeOutputType => {
  const tail = messages[messages.length - 1];

  // Guard clause: anything other than an identifiable router AI message is kept.
  if (!(tail instanceof AIMessage) || tail.name !== "router" || !tail.id) {
    return {};
  }

  return { messages: [new RemoveMessage({ id: tail.id })] };
};

/** Options accepted by `buildPageAgentSubgraph`. */
interface BuildPageAgentSubgraphParams {
  /** Name of the prompt loaded for the page agent node. */
  promptName: PromptName;
  /** Node name given to the page agent. */
  name: NodeNamesType;
  /** Shared graph parameters forwarded to the node factories. */
  params: DuckGraphParams;
  /** Tools bound to the main page agent node. */
  actionTools: StructuredTool[];
  /** Tools for the optional retrieval step; an empty list omits that branch. */
  retrievalTools: StructuredTool[];
  /** Optional route forwarded to the page agent node. */
  route?: PageHandlerRoute;
}

/**
 * Build and compile the LangGraph subgraph that drives a single page agent.
 *
 * Topology:
 *   START → DELETE_ROUTER_AI_MESSAGE → EXTRACT_FROM_SCREENSHOT ─────────────┐
 *                                    └→ RETRIEVE_INFO → RETRIEVE_INFO_TOOL ─┴→ AGENT → AGENT_TOOL → END
 *
 * The RETRIEVE_INFO / RETRIEVE_INFO_TOOL branch is added only when
 * `retrievalTools` is non-empty. In that case AGENT is wired from both
 * EXTRACT_FROM_SCREENSHOT and RETRIEVE_INFO_TOOL as a single multi-source edge.
 * Otherwise EXTRACT_FROM_SCREENSHOT feeds AGENT directly.
 *
 * @param params - shared graph parameters; `captureImages` also controls
 *   whether a PNG of the compiled graph is drawn
 * @param retrievalTools - tools for the optional retrieval agent
 * @param actionTools - tools bound to the main page agent
 * @param promptName - prompt loaded for the page agent; the retrieval agent
 *   reuses its messages plus an extra system instruction
 * @param name - node name of the page agent (also used in the graph PNG name)
 * @param route - optional route forwarded to the page agent node
 * @returns the compiled subgraph
 */
const buildPageAgentSubgraph = async ({
  params,
  retrievalTools,
  actionTools,
  promptName,
  name,
  route,
}: BuildPageAgentSubgraphParams) => {
  const prompt = await loadPrompt(promptName);
  const { node: agentNode, toolNode: agentToolNode } = await getAgentNodes({
    params,
    tools: actionTools,
    prompt,
    name,
    route,
  });

  const subgraphBuilder = new StateGraph(graphState)
    .addNode(
      PageAgentSubgraphNodeNames.DELETE_ROUTER_AI_MESSAGE,
      deleteRouterAIMessageNode
    )
    .addNode(
      PageAgentSubgraphNodeNames.EXTRACT_FROM_SCREENSHOT,
      await getExtractFromScreenshotNode(params)
    )
    .addNode(PageAgentSubgraphNodeNames.AGENT, agentNode)
    .addNode(PageAgentSubgraphNodeNames.AGENT_TOOL, agentToolNode)

    .addEdge(START, PageAgentSubgraphNodeNames.DELETE_ROUTER_AI_MESSAGE)
    .addEdge(
      PageAgentSubgraphNodeNames.DELETE_ROUTER_AI_MESSAGE,
      PageAgentSubgraphNodeNames.EXTRACT_FROM_SCREENSHOT
    )
    .addEdge(
      PageAgentSubgraphNodeNames.AGENT,
      PageAgentSubgraphNodeNames.AGENT_TOOL
    )
    .addEdge(PageAgentSubgraphNodeNames.AGENT_TOOL, END);

  if (retrievalTools.length > 0) {
    // The retrieval prompt is only needed for this branch, so build it here
    // rather than unconditionally.
    const retrieveInfoPrompt = ChatPromptTemplate.fromMessages([
      ...prompt.promptMessages,
      [
        "system",
        "your responsibility is to retrieve information required to fulfill the user's request. Make sure you call all the tools required to retrieve the information in parallel.",
      ],
    ]);

    const { node: retrieveInfoNode, toolNode: retrieveInfoToolNode } =
      await getAgentNodes({
        params,
        tools: retrievalTools,
        prompt: retrieveInfoPrompt,
        name: PageAgentSubgraphNodeNames.RETRIEVE_INFO,
      });

    // addNode/addEdge mutate the builder, so the chain's return value is not
    // needed here.
    subgraphBuilder
      .addNode(PageAgentSubgraphNodeNames.RETRIEVE_INFO, retrieveInfoNode)
      .addNode(
        PageAgentSubgraphNodeNames.RETRIEVE_INFO_TOOL,
        retrieveInfoToolNode
      )
      // Second outgoing edge: runs alongside EXTRACT_FROM_SCREENSHOT.
      .addEdge(
        PageAgentSubgraphNodeNames.DELETE_ROUTER_AI_MESSAGE,
        PageAgentSubgraphNodeNames.RETRIEVE_INFO
      )
      .addEdge(
        PageAgentSubgraphNodeNames.RETRIEVE_INFO,
        PageAgentSubgraphNodeNames.RETRIEVE_INFO_TOOL
      )
      // Multi-source edge: AGENT is wired from both the screenshot branch
      // and the retrieval branch.
      .addEdge(
        [
          PageAgentSubgraphNodeNames.EXTRACT_FROM_SCREENSHOT,
          PageAgentSubgraphNodeNames.RETRIEVE_INFO_TOOL,
        ],
        PageAgentSubgraphNodeNames.AGENT
      );
  } else {
    subgraphBuilder.addEdge(
      PageAgentSubgraphNodeNames.EXTRACT_FROM_SCREENSHOT,
      PageAgentSubgraphNodeNames.AGENT
    );
  }

  const graph = subgraphBuilder.compile();

  if (params.captureImages) {
    await drawGraph(graph as CompiledGraph<any>, `${name}_graph.png`);
  }

  return graph;
};

export default buildPageAgentSubgraph;
