import * as d3 from "d3";
import { useRef, useEffect, useState } from "react";
import useInteractiveLegend from "./UseInteractiveLegend";

function ROCCurve({
  model,
  model_type,
  figureTitle,
  className,
  curveColors,
  setSelectedGroup,
  dataGlobal,
}) {
  const d3Container = useRef();
  const [dimensions, setDimensions] = useState({
    width: 0,
    height: 0,
    margins: {},
  });

  const svg = d3.select(d3Container.current);

  function useROCCurve(svgRef) {
    useEffect(() => {
      // Define dimensions and margins
      const figuresDiv = document.getElementsByClassName("figures-container")[0];
      const width = figuresDiv.clientWidth * 0.48;
      const height = width;
      const margins = {
        top: .2 * height,
        right: 0.2 * width,
        bottom: 0.2 * height,
        left: 0.2 * width,
      };
      setDimensions({ width, height, margins });

      svgRef
        .attr("width", "100%")
        .attr("height", "100%")
        .attr("viewBox", [0, 0, width-10, height]);

      const tooltip = d3
        .select("body")
        .append("div")
        .attr(
          "class",
          "absolute opacity-0 bg-gray-700 text-white text-sm px-2 py-1 rounded pointer-events-none"
        );

      // Create scales
      const xScale = d3
        .scaleLinear()
        .domain([0, 1])
        .range([margins.left, width - margins.right]);
      const yScale = d3
        .scaleLinear()
        .domain([0, 1])
        .range([height - margins.bottom, margins.top]);

      // Axes
      const xAxis = svgRef
        .append("g")
        .attr("transform", `translate(0, ${height - margins.bottom})`)
        .attr("stroke-width", 2)
        .call(d3.axisBottom(xScale));

      xAxis
        .append("text")
        .attr("class", "axis-label")
        .attr("y", margins.bottom * 0.9)
        .attr(
          "x",
          (width - margins.left - margins.right) / 2 + margins.left
        )
        .attr("fill", "#000")
        .attr("text-anchor", "middle")
        .style("font-size", "18px")
        .text("False Positive Rate");

      const yAxis = svgRef
        .append("g")
        .attr("transform", `translate(${margins.left},0)`)
        .attr("stroke-width", 2)
        .call(d3.axisLeft(yScale));

      yAxis
        .append("text")
        .attr("class", "axis-label")
        .attr("y", -(margins.left * 0.7))
        .attr(
          "x",
          -(height - margins.bottom - margins.top) / .852 + margins.top
        )
        .attr("fill", "#000")
        .attr("text-anchor", "middle")
        .attr("transform", "rotate(-90)")
        .style("font-size", "18px")
        .text("True Positive Rate");

      // Chance line
      svgRef
        .append("line")
        .style("stroke", "black")
        .style("stroke-width", 2)
        .style("stroke-dasharray", "3, 3")
        .attr("x1", margins.left)
        .attr("y1", height - margins.bottom)
        .attr("x2", width - margins.right)
        .attr("y2", margins.top);

      // Create ROC line for each class in model
      let selectedPoint = { index: null, class_name: null };
      Object.keys(model.roc[model_type]).forEach((class_name) => {
        let data = model.roc[model_type][class_name];

        // Parse data
        data.fpr = data.fpr.map((x) => +x);
        data.tpr = data.tpr.map((y) => +y);
        data.tps = Array.from(
          { length: data.posneg_mat[0].length },
          (_, i) => data.posneg_mat.filter((row) => row[i] === 0).length
        );
        data.fps = Array.from(
          { length: data.posneg_mat[0].length },
          (_, i) => data.posneg_mat.filter((row) => row[i] === 2).length
        );
        data.tns = Array.from(
          { length: data.posneg_mat[0].length },
          (_, i) => data.posneg_mat.filter((row) => row[i] === 1).length
        );
        data.fns = Array.from(
          { length: data.posneg_mat[0].length },
          (_, i) => data.posneg_mat.filter((row) => row[i] === 3).length
        );
        data.ppv = data.tps.map((d, i) => d / (d + data.fps[i]));
        data.npv = data.tns.map((d, i) => d / (d + data.fns[i]));
        data.sensitivity = data.tps.map((d, i) => d / (d + data.fns[i]));
        data.specificity = data.tns.map((d, i) => d / (d + data.fps[i]));

        // Create line generator
        const line = d3
          .line()
          .x((d) => xScale(d.fpr))
          .y((d) => yScale(d.tpr));

        // ROC curve line
        svgRef
          .append("path")
          .datum(
            data.fpr.map((d, i) => ({ fpr: d, tpr: data.tpr[i] }))
          )
          .attr("fill", "none")
          .attr("stroke", curveColors[className](class_name))
          .attr("stroke-width", 3)
          .attr("d", line)
          .attr("class", `roc-line ${class_name}`)
          .on("mouseover", function (e, d) {
            d3.select(this);
            tooltip.transition().duration(200).style("opacity", 1);
            tooltip
              .html(
                `${class_name} vs. rest<br/>
                AUC: ${data.auc.toFixed(2)}`
              )
              .style("left", e.clientX + 5 + "px")
              .style("top", e.clientY + 5 + "px");
          });

        // ROC curve points
        svgRef
          .selectAll(`.scatter-circle ${class_name}`)
          .data(
            data.tpr.map((d, i) => ({
              fpr: data.fpr[i],
              tpr: d,
              auc: data.auc,
              tps: data.tps[i],
              fps: data.fps[i],
              tns: data.tns[i],
              fns: data.fns[i],
              ppv: data.ppv[i],
              npv: data.npv[i],
              sensitivity: data.sensitivity[i],
              specificity: data.specificity[i],
              outcome: data.posneg_mat.map((row) => row[i]),
              index: i,
            }))
          )
          .join("circle")
          .attr("class", `roc-circle ${class_name}`)
          .attr("cx", (d) => xScale(d.fpr))
          .attr("cy", (d) => yScale(d.tpr))
          .attr("r", 2)
          .attr("fill", curveColors[className](class_name))
          .attr("stroke", "black")
          .style("cursor", "pointer")
          .attr("fill-opacity", 1)
          .on("mouseover", function (e, d) {
            if (
              !(
                d.index == selectedPoint.index &&
                class_name == selectedPoint.class_name
              )
            ) {
              d3.select(this).attr("r", 6);
            }
            tooltip.transition().duration(200).style("opacity", 1);
            tooltip
              .html(
                `${class_name} vs. rest<br/>
                <div class="flex justify-between">
                  <div>AUC: ${d.auc.toFixed(2)}</div>
                  <div>click to view</div>
                </div>
                <table class="border border-collapse mt-2">
                  <tr>
                    <td class="border px-1">${class_name}</td>
                    <td class="border px-1">TPs: ${d.tps}</td>
                    <td class="border px-1">FPs: ${d.fps}</td>
                    <td class="border px-1">PPV: ${d.ppv.toFixed(2)}</td>
                  </tr>
                  <tr>
                    <td class="border px-1">rest</td>
                    <td class="border px-1">FNs: ${d.fns}</td>
                    <td class="border px-1">TNs: ${d.tns}</td>
                    <td class="border px-1">NPV: ${d.npv.toFixed(2)}</td>
                  </tr>
                  <tr>
                    <td class="border px-1"></td>
                    <td class="border px-1">Sens: ${d.sensitivity.toFixed(
                      2
                    )}</td>
                    <td class="border px-1">Spec: ${d.specificity.toFixed(
                      2
                    )}</td>
                    <td class="border px-1"></td>
                  </tr>
                </table>`
              )
              .style("left", e.clientX + 5 + "px")
              .style("top", e.clientY + 5 + "px");
          })
          .on("mouseout", function (e, d) {
            if (
              !(
                d.index == selectedPoint.index &&
                class_name == selectedPoint.class_name
              )
            ) {
              d3.select(this).attr("r", 2);
            }
            tooltip.transition().duration(200).style("opacity", 0);
          })
          .on("click", function (e, d) {
            if (
              d.index == selectedPoint.index &&
              class_name == selectedPoint.class_name
            ) {
              selectedPoint = { index: null, class_name: null };
              d3.select(this)
                .attr("r", 2)
                .attr("stroke", "black")
                .attr("stroke-width", 1);
            } else {
              svgRef
                .selectAll(`.roc-circle`)
                .attr("r", 2)
                .attr("stroke", "black")
                .attr("stroke-width", 1);

              selectedPoint = { index: d.index, class_name: class_name };

              d3.select(this)
                .attr("r", 6)
                .attr("stroke", "red")
                .attr("stroke-width", 4)
                .raise();
            }

            if (selectedPoint.index && selectedPoint.class_name) {
              let indices = d.outcome.reduce(
                (acc, cur, idx) => (cur === 0 ? [...acc, idx] : acc),
                []
              );
              let currentIndices =
                model.model_params[model_type + "_indices"];
              let currentColors = [];
              for (let n = 0; n < currentIndices.length; n++) {
                if (d.outcome[n] == 0) {
                  currentColors[n] = "lime";
                } else if (d.outcome[n] == 1) {
                  currentColors[n] = "blue";
                } else if (d.outcome[n] == 2) {
                  currentColors[n] = "red";
                } else if (d.outcome[n] == 3) {
                  currentColors[n] = "yellow";
                }
              }
              setSelectedGroup({
                indices: currentIndices,
                colors: currentColors,
              });
            } else {
              setSelectedGroup({ indices: [], colors: [] });
            }
          });

        svgRef.on("mouseout", function (e, d) {
          tooltip.transition().duration(200).style("opacity", 0);
        });
      });

      return () => {
        tooltip.remove();
      };
    }, [dataGlobal]);
  }

  useROCCurve(svg);

  // Interactive legend
  useInteractiveLegend(
    svg,
    "roc",
    dimensions,
    model.classes,
    curveColors[className],
    "categorical",
    model.model_params.response_variable_names[0],
    [],
    [],
    dataGlobal
  );

  return (
    <>
      <div>
        {figureTitle && (
          <h3 className="text-xl font-bold mb-4">{figureTitle}</h3>
        )}
        <div className="relative w-full h-full">
          <svg className="w-full h-full" ref={d3Container} />
        </div>
      </div>
    </>
  );
}

// Expose the ROC curve figure component as this module's default export.
export default ROCCurve;
