import { ATTRIBUTE_SOURCE_KEYS, ATTRIBUTE_SOURCES } from "duck/graph/constants";
import { StringSetter } from "duck/graph/types";
import { z } from "zod";
import { tool } from "@langchain/core/tools";

import {
  vectorStoreSearch,
  VectorStoreSearchParameters,
} from "shared/api/vectorstore/api";

/** A key of `ATTRIBUTE_SOURCE_KEYS`; selects which attribute source a retrieval tool targets. */
type AttributeSourceKey = keyof typeof ATTRIBUTE_SOURCE_KEYS;

// Reusable leaf schemas. Zod schemas are immutable, so a single instance can
// safely back several object fields.
const booleanField = z.boolean();
const nullableStringField = z.string().nullable();

/**
 * Shape of the `filteringConfig` object carried by each data attribute:
 * one required boolean per flag listed below.
 */
const FilteringConfigSchema = z.object({
  contains: booleanField,
  decimalNumbers: booleanField,
  empty: booleanField,
  lowCardinality: booleanField,
  minMax: booleanField,
  negativeNumbers: booleanField,
  startsWith: booleanField,
});

/**
 * Validation schema for a single data attribute. `retrieveRelevantAttributes`
 * applies it to each JSON-decoded vector-store document.
 *
 * Field order is significant for output: parsed objects (and their
 * `JSON.stringify` form returned by the retrieval tool) follow this order.
 */
const DataAttributeSchema = z.object({
  ID: z.string(),
  columnName: nullableStringField,
  description: nullableStringField,
  displayName: nullableStringField,
  filtering: booleanField,
  filteringConfig: FilteringConfigSchema,
  grouping: booleanField,
  nullable: booleanField,
  sorting: booleanField,
  type: z.string(),
  unitAtRest: nullableStringField,
});

/** Static type of a validated data attribute, derived from `DataAttributeSchema`. */
type DataAttribute = z.infer<typeof DataAttributeSchema>;

/**
 * Searches the vector store for attribute documents matching `query`.
 *
 * Each returned document is JSON-decoded and validated against
 * `DataAttributeSchema`. The order the vector store returned them in is
 * preserved. Documents that fail to decode or validate are logged and
 * skipped, so one bad document does not fail the whole lookup.
 *
 * @param query - Free-text query sent to the vector store.
 * @param source - Optional source forwarded to `vectorStoreSearch`.
 * @param k - Optional result count forwarded to `vectorStoreSearch`.
 * @param distanceThreshold - Optional distance threshold forwarded to `vectorStoreSearch`.
 * @param filtering - When defined, keep only attributes whose `filtering` flag equals it.
 * @param sorting - When defined, keep only attributes whose `sorting` flag equals it.
 * @param grouping - When defined, keep only attributes whose `grouping` flag equals it.
 * @returns Valid attributes passing the flag filters; `[]` when the search returns no data.
 */
const retrieveRelevantAttributes = async (
  query: string,
  source?: string,
  k?: number,
  distanceThreshold?: number,
  filtering?: boolean,
  sorting?: boolean,
  grouping?: boolean
): Promise<DataAttribute[]> => {
  console.debug("Retrieving relevant attributes", {
    query,
    source,
    k,
    distanceThreshold,
    filtering,
    sorting,
    grouping,
  });

  const params: VectorStoreSearchParameters = {
    query,
    k,
    distanceThreshold,
    source,
  };
  const { data } = await vectorStoreSearch(params);

  if (!data) {
    return [];
  }

  // An undefined flag means "do not filter on this capability".
  const matchesFlags = (attribute: DataAttribute): boolean =>
    (filtering === undefined || attribute.filtering === filtering) &&
    (sorting === undefined || attribute.sorting === sorting) &&
    (grouping === undefined || attribute.grouping === grouping);

  const attributes: DataAttribute[] = [];
  for (const { documentID, document, title, url, metadata } of data) {
    // Full per-document dumps are verbose; keep them at debug level like the
    // rest of this function's diagnostics.
    console.debug("Document retrieved:", {
      documentID,
      document,
      title,
      url,
      source,
      metadata,
    });

    // Keep the try scope to the only call that throws on bad input.
    let parsedDocument: unknown;
    try {
      parsedDocument = JSON.parse(document);
    } catch (error) {
      console.error("Failed to parse attribute document:", document, error);
      continue;
    }

    // safeParse reports schema mismatches as a value instead of throwing.
    const validation = DataAttributeSchema.safeParse(parsedDocument);
    if (!validation.success) {
      console.error(
        "Failed to parse attribute document:",
        document,
        validation.error
      );
      continue;
    }

    if (matchesFlags(validation.data)) {
      attributes.push(validation.data);
    }
  }

  console.debug(
    `Retrieved ${attributes.length} relevant attributes sorted by relevance.`
  );

  return attributes;
};

/**
 * The single `attributeName` input field. Its `.describe()` text is part of
 * the tool's input schema handed to `tool()` in `getAttributeRetrievalTool`.
 */
const attributeNameInput = z
  .string()
  .describe("Name of the attribute to retrieve.");

/** Input schema for the attribute retrieval tool: one `attributeName` string. */
const RetrieveAttributesSchema = z.object({ attributeName: attributeNameInput });

/** Argument object the retrieval tool's handler receives. */
type RetrieveAttributes = z.infer<typeof RetrieveAttributesSchema>;

/**
 * Builds a LangChain tool that looks up attributes of one source via
 * vector-store similarity search.
 *
 * The tool is named `retrieve${source}Attributes`, where `source` is
 * `ATTRIBUTE_SOURCE_KEYS[sourceKey]`.
 *
 * @param sourceKey - Key into `ATTRIBUTE_SOURCE_KEYS` selecting the source.
 * @param k - Forwarded to the vector-store search as the result count.
 * @param distanceThreshold - Forwarded to the vector-store search.
 * @param setEphemeralMessage - Called with a progress message each time the tool runs.
 * @returns A tool accepting `{ attributeName }`. It returns either a "not found"
 *   message or a summary line followed by the matched attributes as
 *   pretty-printed JSON.
 */
const getAttributeRetrievalTool = (
  sourceKey: AttributeSourceKey,
  k: number,
  distanceThreshold: number,
  setEphemeralMessage: StringSetter
) => {
  // Used in messages and the tool name, and as the key into
  // ATTRIBUTE_SOURCES to obtain the vector-store `source` value.
  const source = ATTRIBUTE_SOURCE_KEYS[sourceKey];

  /**
   * Tool handler: searches for attributes matching `attributeName` and
   * formats the result as text for the caller.
   */
  const retrieveAttributes = async ({ attributeName }: RetrieveAttributes) => {
    setEphemeralMessage(`Retrieving ${source} attributes`);
    console.debug(`Retrieving ${source} attributes`, { attributeName });

    const attributes = await retrieveRelevantAttributes(
      attributeName,
      ATTRIBUTE_SOURCES[source],
      k,
      distanceThreshold
    );

    if (attributes.length === 0) {
      return "No attributes found for the given attributeName.";
    }

    return `Retrieved ${attributes.length} relevant ${source} attributes sorted by relevance.\n\n${JSON.stringify(attributes, null, 2)}`;
  };

  return tool(retrieveAttributes, {
    name: `retrieve${source}Attributes`,
    description: `Call this tool to retrieve the relevant ${source} attributes given an attributeName query.`,
    schema: RetrieveAttributesSchema,
  });
};

export default getAttributeRetrievalTool;
