import * as d3 from 'd3';
import {customColors} from './visualizationFlags';

/**
 * Report the current client size of a referenced element.
 *
 * @param {{current: ?{clientWidth: number, clientHeight: number}}} divRef
 *   Ref-like object; nothing happens while `current` is falsy.
 * @param {function({width: number, height: number}): void} setDimensions
 *   Callback that receives the measured `{ width, height }`.
 */
export const handleResize = (divRef, setDimensions) => {
  const element = divRef.current;
  if (!element) {
    return;
  }

  const { clientWidth: width, clientHeight: height } = element;
  setDimensions({ width, height });
};

/**
 * Reset the referenced <svg> element and size it for a new chart.
 *
 * All existing children are removed, then the element is sized to the plot
 * area plus the surrounding margins and given a black, rounded background.
 *
 * @param {{current: *}} svgRef - ref-like object whose `current` is handed to `d3.select`.
 * @param {{width: number, height: number}} dimensions - size of the plot area.
 * @param {{top: number, right: number, bottom: number, left: number}} margin - space around the plot area.
 * @returns {object} the d3 selection wrapping the prepared element.
 */
export const createSVG = (svgRef, dimensions, margin) => {
  const svg = d3.select(svgRef.current);

  // Start from a clean slate so repeated renders do not stack up.
  svg.selectAll('*').remove();

  const outerWidth = dimensions.width + margin.left + margin.right;
  const outerHeight = dimensions.height + margin.top + margin.bottom;

  svg.attr('width', outerWidth);
  svg.attr('height', outerHeight);
  svg.style('background-color', 'black');
  svg.style('border-radius', '10px');

  return svg;
};

/**
 * Abbreviate a number for use as an axis tick label.
 *
 * Values whose magnitude is at least one thousand / million / billion are
 * divided accordingly, rendered with one decimal place and suffixed with
 * "k" / "M" / "B". Smaller values are returned via `toString()` unchanged.
 * The threshold is checked against the absolute value so that negative
 * numbers are abbreviated the same way as positive ones (e.g. "-1.5k").
 *
 * @param {number} d - value to format.
 * @returns {string} the abbreviated label.
 */
const formatNumber = (d) => {
  // Compare the magnitude, not the signed value: a signed comparison would
  // send every negative number to the unabbreviated branch.
  const magnitude = Math.abs(d);

  if (magnitude >= 1e9) return `${(d / 1e9).toFixed(1)}B`; // Billion
  if (magnitude >= 1e6) return `${(d / 1e6).toFixed(1)}M`; // Million
  if (magnitude >= 1e3) return `${(d / 1e3).toFixed(1)}k`; // Thousand
  return d.toString(); // Smaller numbers remain unchanged
};

/**
 * Build the horizontal scale and draw the bottom axis, its title and,
 * optionally, dashed grid lines spanning the plot height.
 *
 * Scale selection by `xAxisType`:
 *  - 'float'   : linear scale over the data extent padded by 10% on each side, then `nice()`d.
 *  - 'varchar' : band scale over the column values with 0.1 padding.
 *  - 'date'    : time scale over the extent of `new Date(value)`.
 * Any other type leaves the scale undefined.
 *
 * @param {object} svg - d3 selection to draw into.
 * @param {Array<object>} data - rows to plot.
 * @param {string} xAxisColumn - row key used for x values; also used as the axis title.
 * @param {string} xAxisType - 'float', 'varchar' or 'date'.
 * @param {{width: number, height: number}} dimensions - size of the plot area.
 * @param {{top: number, left: number}} margin - offsets of the plot area inside the SVG.
 * @param {boolean} grid - when truthy, grid lines are drawn as well.
 * @returns {{svg: object, xScale: object}} the selection and the scale that was built.
 */
export const createXAxis = (svg, data, xAxisColumn, xAxisType, dimensions, margin, grid) => {
  const { width, height } = dimensions;
  const pixelRange = [margin.left, width + margin.left];

  let xScale;
  switch (xAxisType) {
    case 'float': {
      const [lowest, highest] = d3.extent(data, row => row[xAxisColumn]);
      const padding = (highest - lowest) * 0.1;
      xScale = d3.scaleLinear()
        .domain([lowest - padding, highest + padding])
        .range(pixelRange)
        .nice();
      break;
    }
    case 'varchar':
      xScale = d3.scaleBand()
        .domain(data.map(row => row[xAxisColumn]))
        .range(pixelRange)
        .padding(0.1);
      break;
    case 'date':
      xScale = d3.scaleTime()
        .domain(d3.extent(data, row => new Date(row[xAxisColumn])))
        .range(pixelRange);
      break;
    default:
      break;
  }

  // Every type except 'varchar' asks for 5 ticks; float and date also get a
  // dedicated label format.
  const xAxis = d3.axisBottom(xScale);
  if (xAxisType === 'float') {
    xAxis.ticks(5).tickFormat(value => formatNumber(value));
  } else if (xAxisType === 'date') {
    xAxis.ticks(5).tickFormat(d3.timeFormat("%Y-%m-%d"));
  } else if (xAxisType !== 'varchar') {
    xAxis.ticks(5);
  }

  // The axis sits on the bottom edge of the plot area.
  const bottomEdge = margin.top + height;

  const axisGroup = svg.append("g");
  axisGroup
    .attr("transform", `translate(0, ${bottomEdge})`)
    .call(xAxis)
    .style("font-size","10px")
    .style("color", "#F9FAFB");

  // Axis title, centred under the plot area.
  const title = svg.append("text");
  title
    .attr("x", margin.left + (width / 2))
    .attr("y", bottomEdge + 35)
    .style("text-anchor", "middle")
    .text(xAxisColumn)
    .style("font-size", "12px")
    .style("font-weight","bold")
    .style("fill", "#F9FAFB");

  if (!grid) {
    return {svg, xScale};
  }

  // Grid: a second, label-less axis whose ticks reach up across the plot.
  const xAxisGrid = d3.axisBottom(xScale)
    .tickSize(-height)
    .tickFormat('')
    .ticks(5);

  svg.append("g")
    .attr("class", "x-axis-grid")
    .attr("transform", `translate(0, ${bottomEdge})`)
    .call(xAxisGrid)
    .style("color", "#D7D7D7")
    .style("stroke-dasharray", "2,2")
    .style("opacity", "0.3");

  return {svg, xScale};
};

/**
 * Build the vertical scale and draw the left axis, its rotated title and,
 * optionally, dashed grid lines spanning the plot width.
 *
 * Scale selection by `yAxisType`:
 *  - 'float'   : linear scale over the data extent padded by 40% on each side, then `nice()`d.
 *  - 'varchar' : band scale over the column values with 0.1 padding.
 *  - 'date'    : time scale over the extent of `new Date(value)`.
 * Any other type leaves the scale undefined.
 *
 * @param {object} svg - d3 selection to draw into.
 * @param {Array<object>} data - rows to plot.
 * @param {string} yAxisColumn - row key used for y values; also used as the axis title.
 * @param {string} yAxisType - 'float', 'varchar' or 'date'.
 * @param {{width: number, height: number}} dimensions - size of the plot area.
 * @param {{top: number, left: number}} margin - offsets of the plot area inside the SVG.
 * @param {boolean} grid - when truthy, grid lines are drawn as well.
 * @returns {{svg: object, yScale: object}} the selection and the scale that was built.
 */
export const createYAxis = (svg, data, yAxisColumn, yAxisType, dimensions, margin, grid) => {
  const { width, height } = dimensions;
  // The range runs from the bottom edge of the plot area to its top edge.
  const pixelRange = [height + margin.top, margin.top];

  let yScale;
  switch (yAxisType) {
    case 'float': {
      const [lowest, highest] = d3.extent(data, row => row[yAxisColumn]);
      const padding = (highest - lowest) * 0.4;
      yScale = d3.scaleLinear()
        .domain([lowest - padding, highest + padding])
        .range(pixelRange)
        .nice();
      break;
    }
    case 'varchar':
      yScale = d3.scaleBand()
        .domain(data.map(row => row[yAxisColumn]))
        .range(pixelRange)
        .padding(0.1);
      break;
    case 'date':
      yScale = d3.scaleTime()
        .domain(d3.extent(data, row => new Date(row[yAxisColumn])))
        .range(pixelRange);
      break;
    default:
      break;
  }

  // Every type except 'varchar' asks for 5 ticks; float and date also get a
  // dedicated label format.
  const yAxis = d3.axisLeft(yScale);
  if (yAxisType === 'float') {
    yAxis.ticks(5).tickFormat(value => formatNumber(value));
  } else if (yAxisType === 'date') {
    yAxis.ticks(5).tickFormat(d3.timeFormat("%Y-%m-%d"));
  } else if (yAxisType !== 'varchar') {
    yAxis.ticks(5);
  }

  // The axis sits on the left edge of the plot area.
  const leftEdge = `translate(${margin.left}, 0)`;

  const axisGroup = svg.append("g");
  axisGroup
    .attr("transform", leftEdge)
    .call(yAxis)
    .style("font-size","10px")
    .style("color", "#F9FAFB");

  // Axis title: rotated by -90 degrees, so "x" positions it along the
  // vertical direction and "y" along the horizontal one.
  const title = svg.append("text");
  title
    .attr("transform", "rotate(-90)")
    .attr("y", margin.left - 60)
    .attr("x", - (margin.top + (height / 2)))
    .attr("dy", "1em")
    .style("text-anchor", "middle")
    .text(yAxisColumn)
    .style("font-size", "12px")
    .style("font-weight","bold")
    .style("fill", "#F9FAFB");

  if (!grid) {
    return {svg, yScale};
  }

  // Grid: a second, label-less axis whose ticks reach across the plot.
  const yAxisGrid = d3.axisLeft(yScale)
    .tickSize(-width)
    .tickFormat('')
    .ticks(5);

  svg.append("g")
    .attr("class", "y-axis-grid")
    .attr("transform", leftEdge)
    .call(yAxisGrid)
    .style("color", "#D7D7D7")
    .style("stroke-dasharray", "2,2")
    .style("opacity", "0.3");

  return {svg, yScale};
};

// Chart renderers
/**
 * Bubble chart: one circle per row, positioned by the x/y scales, sized by
 * the z column and coloured by the colour column.
 *
 * @param {object} svg - d3 selection to draw into.
 * @param {{xaxis: string, yaxis: string, zaxis: string, color: string}} vizdata
 *   Row keys to read for position, radius and colour.
 * @param {Array<object>} data - rows to plot.
 * @param {{xScale: function, yScale: function}} scales - position scales.
 * @returns {object} the `svg` selection that was passed in.
 */
export const renderBubbleChart = (svg, vizdata, data, scales) => {
  const { xScale, yScale } = scales;

  // Build an accessor for whichever row key `vizdata[role]` names.
  const valueOf = role => row => row[vizdata[role]];
  const sizeValue = valueOf('zaxis');
  const colorValue = valueOf('color');

  // Radius between 2 and 20 on a square-root scale of the z extent.
  const zScale = d3.scaleSqrt();
  zScale.domain(d3.extent(data, sizeValue));
  zScale.range([2, 20]);

  const colorScale = d3.scaleOrdinal();
  colorScale.domain(data.map(colorValue));
  colorScale.range(customColors);

  const paint = row => colorScale(colorValue(row));

  // Append the bubbles
  const bubbles = svg.selectAll(".bubble")
    .data(data)
    .enter()
    .append("circle");

  bubbles
    .attr("class", "bubble")
    .attr("cx", row => xScale(row[vizdata.xaxis]))
    .attr("cy", row => yScale(row[vizdata.yaxis]))
    .attr("r", row => zScale(sizeValue(row)))
    .style("fill", paint)
    .attr("fill-opacity", "0.3")
    .attr("stroke", paint)
    .style("stroke-width", "1.5")
    .attr("stroke-opacity", "1");

  return svg;
};

/**
 * Pie chart of how many rows fall into each distinct value of the x column.
 *
 * The pie is translated to (width / 2, height / 2) and its radius is a third
 * of the smaller plot dimension. Slices keep the order in which categories
 * are first encountered (the pie generator's sorting is disabled).
 *
 * @param {object} svg - d3 selection to draw into.
 * @param {{width: number, height: number}} dimensions - size of the plot area.
 * @param {{xaxis: string}} vizdata - row key holding the category.
 * @param {Array<object>} data - rows to count.
 * @returns {object} the `svg` selection that was passed in.
 */
export const renderPieChart = (svg, dimensions, vizdata, data) => {
  const { width, height } = dimensions;
  const radius = Math.min(width, height) / 3;

  // category -> number of rows, flattened to [{ key, value }, ...]
  const countsByCategory = d3.rollup(
    data,
    rows => rows.length,
    row => row[vizdata.xaxis]
  );
  const pieData = [...countsByCategory].map(([key, value]) => ({ key, value }));

  const colorScale = d3.scaleOrdinal()
    .domain(pieData.map(slice => slice.key))
    .range(customColors);

  const pie = d3.pie()
    .value(slice => slice.value)
    .sort(null);

  // innerRadius 0 makes full wedges rather than a ring.
  const arc = d3.arc()
    .innerRadius(0)
    .outerRadius(radius);

  // One <g class="arc"> per slice inside a translated wrapper group, each
  // holding the wedge path.
  svg.append("g")
    .attr("transform", `translate(${width / 2}, ${height / 2})`)
    .selectAll(".arc")
    .data(pie(pieData))
    .enter()
    .append("g")
    .attr("class", "arc")
    .append("path")
    .attr("d", arc)
    .attr("fill", slice => colorScale(slice.data.key))
    .attr("fill-opacity", "0.85")
    .attr("stroke", "black")
    .style("stroke-width", "1.3");

  return svg;
};

/**
 * Box plot: one box-and-whisker glyph per distinct value of the x column.
 *
 * For every category the y values are summarised into quartiles. The box
 * spans q1..q3, a line marks the median, and the whiskers end at the most
 * extreme observations that lie within 1.5 * IQR of the box. They are
 * clamped to data rather than drawn to the theoretical fences, so a whisker
 * never extends past the values that were actually observed.
 *
 * @param {object} svg - d3 selection to draw into.
 * @param {{xaxis: string, yaxis: string}} vizdata - row keys for category and value.
 * @param {Array<object>} data - rows to summarise.
 * @param {{xScale: function, yScale: function}} scales - `xScale` must expose
 *   `bandwidth()` (it is called for every category).
 * @returns {object} the `svg` selection that was passed in.
 */
export const renderBoxPlot = (svg, vizdata, data, scales) => {
  const { xScale, yScale } = scales;

  // Compute summary statistics for each category
  const summaryStats = d3.rollups(
    data,
    values => {
      const sortedValues = values.map(d => d[vizdata.yaxis]).sort(d3.ascending);
      const q1 = d3.quantile(sortedValues, 0.25);
      const median = d3.quantile(sortedValues, 0.5);
      const q3 = d3.quantile(sortedValues, 0.75);
      const interQuantileRange = q3 - q1;

      const lowerFence = q1 - 1.5 * interQuantileRange;
      const upperFence = q3 + 1.5 * interQuantileRange;

      // d3.min / d3.max skip undefined, so returning undefined from the
      // accessor filters out observations beyond the fences.
      const lowestInside = d3.min(sortedValues, v => (v >= lowerFence ? v : undefined));
      const highestInside = d3.max(sortedValues, v => (v <= upperFence ? v : undefined));

      // q1 and q3 are interpolated, so the nearest in-fence observation can
      // fall inside the box; never let a whisker start inside it.
      const min = lowestInside === undefined ? q1 : Math.min(q1, lowestInside);
      const max = highestInside === undefined ? q3 : Math.max(q3, highestInside);

      return { q1, median, q3, interQuantileRange, min, max };
    },
    d => d[vizdata.xaxis]
  );

  // Draw the boxplots
  summaryStats.forEach(([key, stats], i) => {
    const color = customColors[i % customColors.length];

    const bandStart = xScale(key);
    const bandWidth = xScale.bandwidth();
    const boxWidth = bandWidth * 0.25;  // Make the boxes thinner
    const boxLeft = bandStart + (bandWidth - boxWidth) / 2;
    const boxRight = bandStart + (bandWidth + boxWidth) / 2;
    const centerX = bandStart + bandWidth / 2;

    const yQ1 = yScale(stats.q1);
    const yQ3 = yScale(stats.q3);
    const yMedian = yScale(stats.median);
    const yMin = yScale(stats.min);
    const yMax = yScale(stats.max);

    // Append one straight segment in this category's colour.
    const drawLine = (x1, x2, y1, y2) => svg.append("line")
      .attr("x1", x1)
      .attr("x2", x2)
      .attr("y1", y1)
      .attr("y2", y2)
      .attr("stroke", color);

    // Draw the box
    svg.append("rect")
      .attr("x", boxLeft)
      .attr("y", yQ3)
      .attr("height", yQ1 - yQ3)
      .attr("width", boxWidth)
      .attr("stroke", color)
      .style("fill", color)
      .attr("fill-opacity", "0.6")
      .style("stroke-width", "1.5")
      .attr("stroke-opacity", "1");

    // Median line across the box
    drawLine(boxLeft, boxRight, yMedian, yMedian);

    // Whiskers: min -> q1 and q3 -> max along the band centre
    drawLine(centerX, centerX, yMin, yQ1);
    drawLine(centerX, centerX, yQ3, yMax);

    // Caps at the whisker ends
    drawLine(boxLeft, boxRight, yMin, yMin);
    drawLine(boxLeft, boxRight, yMax, yMax);
  });

  return svg;
};

/**
 * Area chart: a single filled path between the bottom edge of the plot area
 * and the y value of every row.
 *
 * x values are passed through `new Date(...)` before being scaled.
 *
 * @param {object} svg - d3 selection to draw into.
 * @param {{height: number}} dimensions - size of the plot area.
 * @param {{top: number}} margin - offset of the plot area inside the SVG.
 * @param {{xaxis: string, yaxis: string}} vizdata - row keys for x and y.
 * @param {Array<object>} data - rows to plot, in drawing order.
 * @param {{xScale: function, yScale: function}} scales - position scales.
 * @returns {object} the `svg` selection that was passed in.
 */
export const renderAreaChart = (svg, dimensions, margin, vizdata, data, scales) => {
  const { xScale, yScale } = scales;

  const xPosition = row => xScale(new Date(row[vizdata.xaxis]));
  const yPosition = row => yScale(row[vizdata.yaxis]);

  // Lower edge of the fill: the bottom of the plot area.
  const floorY = dimensions.height + margin.top;

  const area = d3.area();
  area.x(xPosition);
  area.y0(floorY);
  area.y1(yPosition);

  // Outline and fill both use the first palette colour.
  const tone = customColors[0];

  const shape = svg.append("path").datum(data);
  shape
    .attr("stroke", tone)
    .style("fill", tone)
    .attr("fill-opacity", "0.5")
    .style("stroke-width", "1.5")
    .attr("stroke-opacity", "1")
    .attr("d", area);

  return svg;
};
