import { CHAT_END_EVENT } from "duck/context/constants";
import { GPT4O_MODEL_SPEC } from "duck/graph/constants";
import loadPrompt from "duck/graph/nodes/loadPrompt";
import { promptNames } from "duck/graph/nodes/types";
import { NodeOutputType, NodeType } from "duck/graph/nodes/utils";
import { graphState, GraphStateType } from "duck/graph/state";
import { DuckGraphParams } from "duck/graph/types";
import { drawGraph, getLLM, NodeNames } from "duck/graph/utils";
import { AIMessageChunk } from "@langchain/core/messages";
import { RunnableConfig } from "@langchain/core/runnables";
import {
  CompiledGraph,
  END,
  START,
  StateGraph,
} from "@langchain/langgraph/web";

import {
  filterBuilderQueryToFilterBuilderState,
  getFiltersQuery,
  mergeFilterGroupStates,
} from "features/ui/Filters/FilterBuilder/utils";

import {
  CLAIM_LIMIT,
  ClaimsAgentSubgraphNodeNames,
  DEFAULT_SORT,
} from "./constants";
import {
  countClaims,
  getExploreInClaimAnalyticsUriTemplate,
  getGeneralConfig,
  listClaims,
} from "./utils";

/**
 * Builds the LangGraph node that fetches the claims the summarizer works on.
 *
 * Each invocation reads the filters and sort from
 * `params.currentState.claimAnalytics`, merges the claims filter with the
 * claims-table filter into one query, and then fetches (concurrently) a page
 * of at most `CLAIM_LIMIT` claims and the total number of matching claims.
 *
 * @param params - Graph parameters; `currentState.claimAnalytics` supplies the
 *   filter/sort query strings and `uiHandlers.setEphemeralMessage` is used to
 *   report progress.
 * @returns A node whose output sets `claimsAgent.claimsList` (a projection of
 *   each loaded claim onto the fields the summarizer needs) and
 *   `claimsAgent.claimsCount` (the total matching count).
 */
const getLoadClaimsNode = async (
  params: DuckGraphParams
): Promise<NodeType> => {
  const name = ClaimsAgentSubgraphNodeNames.LOAD_CLAIMS;
  const setEphemeralMessage = params.uiHandlers.setEphemeralMessage;

  return async (): Promise<NodeOutputType> => {
    console.debug(`LoadClaimsAgent`);

    const {
      claimsFilterQueryString,
      claimsTableFilterQueryString,
      vehiclesFilterQueryString,
      claimsTableSortQueryString,
    } = params.currentState.claimAnalytics;
    // `||` (not `??`) on purpose: an empty sort string also falls back.
    const sort = claimsTableSortQueryString || DEFAULT_SORT;

    setEphemeralMessage("Loading claims...");
    console.debug(`[${name}]: Loading claims with filters:`, {
      claimsFilterQueryString,
      claimsTableFilterQueryString,
      vehiclesFilterQueryString,
      claimsTableSortQueryString,
    });

    // Merge both claims filters into one query so the list and the count are
    // computed over exactly the same set of claims.
    const combinedClaimsFilterQueryString = getFiltersQuery(
      mergeFilterGroupStates(
        filterBuilderQueryToFilterBuilderState(claimsFilterQueryString),
        filterBuilderQueryToFilterBuilderState(claimsTableFilterQueryString)
      )
    );

    // Only the list request needs the general config (for the mileage unit).
    const fetchClaimsList = async () => {
      const generalConfig = await getGeneralConfig();
      return listClaims(
        {
          filter: combinedClaimsFilterQueryString,
          vehiclesFilter: vehiclesFilterQueryString,
          sort,
          mileageUnit: generalConfig.data.mileageUnit,
          analytics: true,
        },
        CLAIM_LIMIT
      );
    };

    // The list and count requests do not depend on each other, so run them
    // concurrently instead of back-to-back.
    const [claims, countResponse] = await Promise.all([
      fetchClaimsList(),
      countClaims({
        filter: combinedClaimsFilterQueryString,
        vehiclesFilter: vehiclesFilterQueryString,
        analytics: true,
      }),
    ]);

    // Keep only the fields passed on to the summarization prompt.
    const claimsToSummarize = claims.map((claim) => ({
      externalID: claim.externalID,
      date: claim.date,
      mileage: claim.mileage,
      mileageUnit: claim.mileageUnit,
      costTotal: claim.costTotal,
      parts: claim.parts,
      notesTechnician: claim.notesTechnician,
      notesCustomer: claim.notesCustomer,
      notes: claim.notes,
      group: claim.group,
      subgroup: claim.subgroup,
      laborCodeInformation: claim.laborCodeInformation,
      failedPartInformation: claim.failedPartInformation,
    }));

    const totalCount = countResponse.data.count;
    console.debug(
      `[${name}]: Loaded ${claims.length} claims, total count: ${totalCount}\n${JSON.stringify(
        claimsToSummarize[0],
        null,
        2
      )}`
    );

    return {
      claimsAgent: {
        claimsList: claimsToSummarize,
        claimsCount: totalCount,
      },
    };
  };
};

/**
 * Returns true when `error` is an object whose `code` or `status` property is
 * 400 (HTTP Bad Request). Non-object values (including `null`) return false.
 */
const isBadRequestError = (error: unknown): boolean => {
  if (typeof error !== "object" || error === null) {
    return false;
  }
  const { code, status } = error as { code?: unknown; status?: unknown };
  return code === 400 || status === 400;
};

/**
 * Builds the LangGraph node that asks the LLM to summarize the loaded claims.
 *
 * The `SUMMARIZE_CLAIMS_AGENT` prompt and the `GPT4O_MODEL_SPEC` LLM are
 * resolved once, when the node is created. Each invocation streams every
 * event to `params.uiHandlers.handleStreamEvent` and keeps the output of the
 * last `CHAT_END_EVENT` event as the final message.
 *
 * If the first attempt fails with a 400 error, the node retries once with
 * `Math.floor(CLAIM_LIMIT / 2)` claims.
 *
 * @param params - Graph parameters; `currentState.claimAnalytics` supplies the
 *   sort and `uiHandlers` receives progress messages and stream events.
 * @returns A node that reads `claimsAgent.claimsList` / `claimsAgent.claimsCount`
 *   from the state and returns `{ messages: [finalMessage] }`, or `{}` when no
 *   final message was captured.
 * @throws Error when `claimsList` or `claimsCount` is absent from the state;
 *   otherwise rethrows any non-400 error, or any error from the retry.
 */
const getSummarizeClaimsNode = async (
  params: DuckGraphParams
): Promise<NodeType> => {
  const prompt = await loadPrompt(promptNames.SUMMARIZE_CLAIMS_AGENT);
  const llm = getLLM(GPT4O_MODEL_SPEC);
  const agent = prompt.pipe(llm);

  const name = ClaimsAgentSubgraphNodeNames.SUMMARIZE_CLAIMS;
  const setEphemeralMessage = params.uiHandlers.setEphemeralMessage;

  return async (
    state: GraphStateType,
    config: RunnableConfig = {}
  ): Promise<NodeOutputType> => {
    console.debug(`SummarizeClaimsAgent`);

    const { userInput } = state;
    const { claimsList, claimsCount } = state.claimsAgent;

    // Check for null/undefined explicitly: a count of 0 (no claims match the
    // filters) is a legitimate result, not missing state.
    if (claimsList == null || claimsCount == null) {
      throw new Error(
        `[${name}]: Missing claimsList, or claimsCount in state: ${JSON.stringify(
          state
        )}`
      );
    }

    const { claimsTableSortQueryString } = params.currentState.claimAnalytics;
    const sort = claimsTableSortQueryString || DEFAULT_SORT;

    setEphemeralMessage("Summarizing claims...");
    const exploreInClaimAnalyticsUriTemplate =
      getExploreInClaimAnalyticsUriTemplate();

    // Streams one summarization run over at most `limit` claims and returns
    // the output of the last chat-end event (undefined if none was seen).
    const streamAgent = async (limit: number) => {
      const agentInput = {
        claimsSort: sort,
        totalCount: claimsCount,
        maxLimit: limit,
        claimsToSummarize: claimsList.slice(0, limit),
        exploreInClaimAnalyticsUriTemplate,
        userInput,
        messages: [],
      };

      const stream = agent.streamEvents(agentInput, {
        version: "v2",
        ...config,
      });

      let finalMessage: AIMessageChunk | undefined;
      for await (const event of stream) {
        params.uiHandlers.handleStreamEvent(event);

        if (event.event === CHAT_END_EVENT) {
          finalMessage = event.data.output;
        }
      }

      console.debug(`[${name}]: Final response:`, finalMessage);

      return finalMessage;
    };

    let agentMessage: AIMessageChunk | undefined;
    try {
      agentMessage = await streamAgent(CLAIM_LIMIT);
    } catch (error: unknown) {
      console.error(`[${name}]: Error processing invocation:`, error);

      // Only a 400 is retried (once, with half the claims); everything else
      // propagates unchanged.
      if (!isBadRequestError(error)) {
        throw error;
      }

      const smallerLimit = Math.floor(CLAIM_LIMIT / 2);
      setEphemeralMessage(`Something went wrong, retrying on a smaller set...`);
      console.debug(
        `[${name}]: Received 400 error. Retrying with claim limit: ${smallerLimit}`
      );
      agentMessage = await streamAgent(smallerLimit);
    }

    if (agentMessage !== undefined) {
      return {
        messages: [agentMessage],
      };
    }

    return {};
  };
};

// This agent is its own subgraph because it is expected to evolve into a more complex agent.
/**
 * Creates and compiles the claims-agent subgraph, a linear pipeline:
 * START -> LOAD_CLAIMS -> SUMMARIZE_CLAIMS -> END.
 *
 * When `params.captureImages` is truthy, the compiled graph is also rendered
 * with `drawGraph` to `${NodeNames.CLAIMS}_graph.png`.
 *
 * @param params - Graph parameters handed to each node factory.
 * @returns The compiled subgraph.
 */
const getGraph = async (params: DuckGraphParams) => {
  // Build the nodes up front, in the same order they are registered below.
  const loadClaimsNode = await getLoadClaimsNode(params);
  const summarizeClaimsNode = await getSummarizeClaimsNode(params);

  // Kept as one chain: each addNode call widens the builder's type with the
  // new node name, which the addEdge calls rely on.
  const graph = new StateGraph(graphState)
    .addNode(ClaimsAgentSubgraphNodeNames.LOAD_CLAIMS, loadClaimsNode)
    .addNode(ClaimsAgentSubgraphNodeNames.SUMMARIZE_CLAIMS, summarizeClaimsNode)
    .addEdge(START, ClaimsAgentSubgraphNodeNames.LOAD_CLAIMS)
    .addEdge(
      ClaimsAgentSubgraphNodeNames.LOAD_CLAIMS,
      ClaimsAgentSubgraphNodeNames.SUMMARIZE_CLAIMS
    )
    .addEdge(ClaimsAgentSubgraphNodeNames.SUMMARIZE_CLAIMS, END)
    .compile();

  if (!params.captureImages) {
    return graph;
  }

  await drawGraph(
    graph as CompiledGraph<any>,
    `${NodeNames.CLAIMS}_graph.png`
  );
  return graph;
};

export default getGraph;
