import { MetricName } from "@cur8/measurements";
import { skipToken, useQuery } from "@tanstack/react-query";
import { APIClient } from "lib/api/client";
import { queryMetrics } from "lib/api/metrics";
import { CountryCode } from "lib/country";
import { useAPIClient } from "render/context/APIContext";
import { useSession } from "render/context/MSALContext";
import { useLatestVisitCountry } from "render/hooks/api/queries/useLatestVisitCountry";

/**
 * Builds the query cache key under which a report metric is stored.
 *
 * @param metric - Name of the metric the query fetches.
 * @returns A two-element key: the "report-metrics" namespace followed by the
 *   metric name.
 */
function queryKey(metric: MetricName) {
  const namespace = "report-metrics";
  return [namespace, metric];
}

/**
 * Requests one metric for a patient by delegating to `queryMetrics`.
 *
 * @param params.metricName - Metric to query.
 * @param params.api - API client; its `measurement` member is handed to
 *   `queryMetrics`.
 * @param params.patientId - Patient whose metric is requested.
 * @param params.signal - Optional abort signal, forwarded in the request
 *   options.
 * @param params.countryCode - Country code sent along with the query.
 * @returns Whatever `queryMetrics` returns for this request.
 */
function fetchMetric<T extends MetricName>(params: {
  metricName: T;
  api: APIClient;
  patientId: string;
  signal?: AbortSignal;
  countryCode: CountryCode;
}) {
  const measurementApi = params.api.measurement;
  const query = {
    countryCode: params.countryCode,
    metricName: params.metricName,
    patientId: params.patientId,
  };
  const requestOptions = { signal: params.signal };

  return queryMetrics(measurementApi, query, requestOptions);
}

/**
 * Query hook that loads one metric for the session's patient.
 *
 * The request needs the country code provided by `useLatestVisitCountry`;
 * until that value is available the query is skipped.
 *
 * @param metricName - Metric to load; also part of the query key.
 * @returns The `useQuery` result for the metric request.
 */
export function useRiskMetric<M extends MetricName>(metricName: M) {
  const { data: countryCode } = useLatestVisitCountry();
  const { patientId } = useSession();
  const api = useAPIClient();
  return useQuery({
    // `skipToken` disables the query while the country code is missing. This
    // matches useSkinLesionCountMetric and avoids passing a falsy,
    // non-function `queryFn` together with a separate `enabled` flag.
    queryFn: countryCode
      ? ({ signal }) =>
          fetchMetric({
            api,
            countryCode,
            metricName,
            patientId,
            signal,
          })
      : skipToken,
    // NOTE(review): the key holds only the metric name, so cached data is
    // shared across country codes and patient ids — confirm both are stable
    // for the lifetime of the query cache.
    queryKey: queryKey(metricName),
    // Cached data is never considered stale, so it is not refetched
    // automatically.
    staleTime: Infinity,
  });
}

/**
 * Query hook that loads the detected-markings count for one scan version of
 * the session's patient.
 *
 * The query is skipped (via `skipToken`) until both `scanId` and
 * `scanVersion` are provided.
 *
 * @param scanId - Scan to count markings for.
 * @param scanVersion - Version of that scan.
 * @returns The `useQuery` result for the count request.
 */
export function useSkinLesionCountMetric({
  scanId,
  scanVersion,
}: {
  scanId?: string;
  scanVersion?: string;
}) {
  const { patientId } = useSession();
  const api = useAPIClient();
  return useQuery({
    queryFn:
      scanId && scanVersion
        ? ({ signal }) => {
            const req = api.detectedMarkings.getCount({
              patientId,
              scanId,
              scanVersion,
            });

            // Abandon the in-flight request when the query is cancelled.
            signal?.addEventListener("abort", () => req.abandon());

            return req.result;
          }
        : skipToken,
    // The scan identifiers are part of the key so every scan/version gets its
    // own cache entry. With a constant key and `staleTime: Infinity`, the
    // count cached for one scan would be served for every other scan. The
    // leading two elements are unchanged, so prefix-based invalidation still
    // matches.
    queryKey: ["report-metrics", "skin-lesion-count", scanId, scanVersion],
    // Cached data is never considered stale, so it is not refetched
    // automatically.
    staleTime: Infinity,
  });
}
