import React, { useState } from "react";
import { LineChart } from "@mui/x-charts/LineChart";
import { axisClasses } from "@mui/x-charts/ChartsAxis";
import { createTheme, ThemeProvider } from "@mui/material/styles";

/**
 * Application-wide MUI theme.
 *
 * - Palette mode is "dark".
 * - Primary colour is light blue (#90caf9); secondary colour is light pink (#f48fb1).
 * - The root of every MUI `Typography` component is overridden to white text.
 *
 * NOTE(review): the override only applies to MUI Typography components; plain HTML
 * elements (h1, p, …) are not styled by it.
 */
const darkTheme = createTheme({
  palette: {
    mode: "dark",
    primary: { main: "#90caf9" },
    secondary: { main: "#f48fb1" },
  },
  components: {
    MuiTypography: {
      styleOverrides: { root: { color: "#fff" } },
    },
  },
});

const LaunchIt = () => {
  const [stages, setStages] = useState([]);
  const [totalAltitude, setTotalAltitude] = useState(0);
  const [totalVelocity, setTotalVelocity] = useState(0);
  const [totalAltitudeZeroVelocity, setTotalAltitudeZeroVelocity] = useState(0); // State to track altitude at zero velocity
  const [graphData, setGraphData] = useState({
    time: [],
    velocity: [],
    altitude: [],
  });

  const GRAVITY = 9.81; // m/s²

  const addStage = (event) => {
    event.preventDefault();

    const isp = parseFloat(event.target.isp.value);
    const m0 = parseFloat(event.target.m0.value);
    const mp = parseFloat(event.target.mp.value);
    const mpDot = parseFloat(event.target.mpDot.value);

    if (isNaN(isp) || isNaN(m0) || isNaN(mp) || isNaN(mpDot)) {
      alert("Please enter valid numbers for all fields.");
      return;
    }

    if (mp <= 0) {
      alert("Propellant mass (mₚ) must be greater than 0.");
      return;
    }

    const stage = calculateStage(isp, m0, mp, mpDot);
    setStages((prevStages) => [...prevStages, stage]);

    const updated = appendGraphData(
      stage,
      totalVelocity,
      totalAltitude,
      graphData.time.at(-1) || 0,
      false,
    );

    setTotalAltitude(updated.cumulativeAltitude);
    setTotalVelocity(updated.cumulativeVelocity);
  };

  const plotExample = () => {
    const exampleStages = [
      { isp: 3000, m0: 10000, mp: 6272, mpDot: 200 },
      { isp: 2300, m0: 2160, mp: 1568, mpDot: 60 },
    ];

    let currentAltitude = totalAltitude;
    let currentVelocity = totalVelocity;
    let currentTime = graphData.time.at(-1) || 0;

    const newStages = [];

    exampleStages.forEach(({ isp, m0, mp, mpDot }) => {
      const stage = calculateStage(isp, m0, mp, mpDot);
      newStages.push(stage);

      const updated = appendGraphData(
        stage,
        currentVelocity,
        currentAltitude,
        currentTime,
      );
      currentVelocity = updated.cumulativeVelocity;
      currentAltitude = updated.cumulativeAltitude;
      currentTime = updated.cumulativeTime;
    });

    setStages((prevStages) => [...prevStages, ...newStages]);
    setTotalAltitude(currentAltitude);
    setTotalVelocity(currentVelocity);
  };

  const calculateStage = (isp, m0, mp, mpDot) => {
    const massNumber = m0 / (m0 - mp);
    const tf = mp / mpDot;
    const deltaV =
      isp * Math.log(massNumber) -
      (GRAVITY * (massNumber - 1) * (m0 - mp)) / mpDot;
    const deltaH =
      ((isp * (m0 - mp)) / mpDot) * (massNumber - 1 - Math.log(massNumber)) -
      0.5 * GRAVITY * ((m0 - mp) / mpDot) ** 2 * (massNumber - 1) ** 2;

    return { isp, m0, mp, mpDot, tf, deltaV, deltaH };
  };

  const appendGraphData = (
    stage,
    startVelocity,
    startAltitude,
    startTime,
    isLastStage = false,
  ) => {
    const timeData = [];
    const velocityData = [];
    const altitudeData = [];

    let cumulativeVelocity = startVelocity;
    let cumulativeAltitude = startAltitude;
    let cumulativeTime = startTime;

    // Calculate thrust phase data
    for (let t = 0; t <= stage.tf; t += stage.tf / 100) {
      const mass = stage.m0 - stage.mpDot * t;
      const velocity =
        stage.isp * Math.log(stage.m0 / mass) -
        GRAVITY * t +
        cumulativeVelocity;
      const altitude =
        (stage.isp / stage.mpDot) *
          ((stage.m0 - stage.mpDot * t) *
            Math.log((stage.m0 - stage.mpDot * t) / stage.m0) +
            stage.mpDot * t) -
        0.5 * GRAVITY * t * t +
        cumulativeAltitude;

      timeData.push(cumulativeTime + t);
      velocityData.push(velocity);
      altitudeData.push(altitude);
    }

    cumulativeVelocity = velocityData[velocityData.length - 1];
    cumulativeAltitude = altitudeData[altitudeData.length - 1];
    cumulativeTime += stage.tf;

    // Calculate inertial data (if this is the last stage)
    const inertialTimeData = [];
    const inertialVelocityData = [];
    const inertialAltitudeData = [];

    let zeroVelocityAltitude = 0;

    if (isLastStage) {
      let t = 0;
      let velocity = cumulativeVelocity;
      let altitude = cumulativeAltitude;

      while (velocity > 0) {
        t += 10; // Time step for inertia
        velocity = cumulativeVelocity - GRAVITY * t;
        altitude =
          cumulativeAltitude + cumulativeVelocity * t - 0.5 * GRAVITY * t * t;

        if (velocity > 0) {
          inertialTimeData.push(cumulativeTime + t);
          inertialVelocityData.push(velocity);
          inertialAltitudeData.push(altitude);
        }

        // Track altitude at zero velocity
        if (velocity <= 0 && zeroVelocityAltitude === 0) {
          zeroVelocityAltitude = altitude;
        }
      }

      // Update total altitude at zero velocity
      if (zeroVelocityAltitude !== 0) {
        setTotalAltitudeZeroVelocity(zeroVelocityAltitude);
      }
    }

    // Update graph data
    setGraphData((prevData) => {
      const filteredTime = isLastStage
        ? prevData.time.filter((t) => t < startTime) // Remove inertial data from the previous call
        : prevData.time;
      const filteredVelocity = isLastStage
        ? prevData.velocity.filter((_, i) => prevData.time[i] < startTime)
        : prevData.velocity;
      const filteredAltitude = isLastStage
        ? prevData.altitude.filter((_, i) => prevData.time[i] < startTime)
        : prevData.altitude;

      return {
        time: [
          ...filteredTime,
          ...timeData,
          ...(isLastStage ? inertialTimeData : []),
        ],
        velocity: [
          ...filteredVelocity,
          ...velocityData,
          ...(isLastStage ? inertialVelocityData : []),
        ],
        altitude: [
          ...filteredAltitude,
          ...altitudeData,
          ...(isLastStage ? inertialAltitudeData : []),
        ],
      };
    });

    return { cumulativeVelocity, cumulativeAltitude, cumulativeTime };
  };

  const computeInertia = () => {
    // const lastStage = stages[stages.length - 1];
    const blankStage = calculateStage(0, 1, 1, 1);

    if (!blankStage) {
      alert("No stages to compute inertia.");
      return;
    }

    // Calculate inertia for the last stage
    appendGraphData(
      blankStage,
      totalVelocity,
      totalAltitude,
      graphData.time.at(-1) || 0,
      true,
    );
  };

  return (
    <ThemeProvider theme={darkTheme}>
      <div
        className="container"
        style={{
          display: "flex",
          minHeight: "100vh",
        }}
      >
        {/* Left Column */}
        <div style={{ width: "30%", paddingTop: "10px", overflowY: "auto" }}>
          <h1>Rocket Stage Calculator</h1>

          <form onSubmit={addStage}>
            <label>
              Specific Impulse (ISP):
              <input type="number" step="0.1" name="isp" required />
            </label>
            <br />
            <label>
              Total Mass (m₀):
              <input type="number" step="0.1" name="m0" required />
            </label>
            <br />
            <label>
              Propellant Mass (mₚ):
              <input type="number" step="0.1" name="mp" required />
            </label>
            <br />
            <label>
              Propellant Flow Rate (ṁₚ):
              <input type="number" step="0.1" name="mpDot" required />
            </label>
            <br />
            <div style={{ display: "flex", gap: "10px" }}>
              <button
                type="submit"
                style={{
                  color: "white",
                  backgroundColor: "#90caf9",
                  border: "2px solid #90caf9",
                  borderRadius: "4px",
                  padding: "3px 10px",
                  cursor: "pointer",
                }}
              >
                Add Stage
              </button>
              <button
                type="button"
                style={{
                  color: "white",
                  backgroundColor: "#90caf9",
                  border: "2px solid #90caf9",
                  borderRadius: "4px",
                  padding: "3px 10px",
                  cursor: "pointer",
                }}
                onClick={plotExample}
              >
                Example
              </button>
            </div>
          </form>

          <h2>Computed Results</h2>
          <ul>
            {stages.map((stage, index) => (
              <li key={index}>
                <h4>Stage {index + 1}</h4>
                <p>Time of Combustion: {stage.tf.toFixed(2)} s</p>
                <p>Increment of Velocity (ΔV): {stage.deltaV.toFixed(2)} m/s</p>
                <p>Increment of Altitude (Δh): {stage.deltaH.toFixed(2)} m</p>
              </li>
            ))}
          </ul>

          <div>
            <h3>Total Results</h3>
            <p>Total Velocity: {totalVelocity.toFixed(2)} m/s</p>
            <p>Total Altitude: {totalAltitude.toFixed(2)} m</p>
            <p>
              Total Altitude at Zero Velocity:{" "}
              {totalAltitudeZeroVelocity.toFixed(2)} m
            </p>
          </div>

          <button
            onClick={computeInertia}
            style={{
              color: "white",
              backgroundColor: "#90caf9",
              border: "2px solid #90caf9",
              borderRadius: "4px",
              padding: "3px 10px",
              cursor: "pointer",
            }}
          >
            Compute Inertia
          </button>
        </div>

        {/* Right Column */}
        <div style={{ width: "70%", paddingTop: "10px" }}>
          <div className="chart-container">
            <h2>Rocket performance for all the stages</h2>
            <LineChart
              xAxis={[{ data: graphData.time, label: "Time (s)" }]}
              yAxis={[
                {
                  label: "Velocity (m/s)",
                  id: "velocity-axis",
                },
                {
                  label: "Altitude (m)",
                  id: "altitude-axis",
                },
              ]}
              leftAxis="velocity-axis"
              rightAxis="altitude-axis"
              margin={{ right: 80, left: 80 }}
              sx={{
                [`& .${axisClasses.left} .${axisClasses.label}`]: {
                  transform: "translateX(-40px)",
                },
                [`& .${axisClasses.right} .${axisClasses.label}`]: {
                  transform: "translateX(40px)",
                },
              }}
              series={[
                {
                  data: graphData.velocity,
                  label: "Velocity (m/s)",
                  color: "rgb(75, 192, 192)",
                  yAxisId: "velocity-axis",
                },
                {
                  data: graphData.altitude,
                  label: "Altitude (m)",
                  color: "rgb(192, 75, 75)",
                  yAxisId: "altitude-axis",
                },
              ]}
              height={800}
            />
          </div>
        </div>
      </div>
    </ThemeProvider>
  );
};

export default LaunchIt;
