"""
Quadcopter 6-DOF Flight Simulator
===================================
Full nonlinear rigid-body dynamics with PD attitude/position control.

State vector (12 elements):
  [x, y, z, φ, θ, ψ, vx, vy, vz, p, q, r]

  (x, y, z)   – position in inertial frame (m)
  (φ, θ, ψ)   – roll, pitch, yaw (rad)
  (vx, vy, vz) – linear velocity, inertial frame (m/s)
  (p, q, r)   – body-frame angular rates (rad/s)

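Rotor layout (top view, + configuration; body axes x forward, y left, z up).
This is the placement implied by the torque model below: rotors 1/3 sit on
the ±y arm and produce roll, rotors 4/2 sit on the ±x arm and produce pitch.
         4 (front, CCW)
  1 (left, CW)   +   3 (right, CW)
         2 (back,  CCW)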

Thrust model:   F_i  = k · ωᵢ²
Torque model:   τφ   = a·k·(ω1² − ω3²)
                τθ   = a·k·(ω2² − ω4²)
                τψ   = b·(ω1² − ω2² + ω3² − ω4²)

Requires: numpy, matplotlib
Run with: python quadcopter_sim.py
"""

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D   # noqa: F401 (registers 3d projection)

# ═══════════════════════════════════════════════════════════
# PHYSICAL CONSTANTS
# ═══════════════════════════════════════════════════════════

# Rigid-body and rotor parameters shared by the dynamics, allocation and
# controller code below.
G        = 9.81          # gravitational acceleration (m/s²)
MASS     = 1.387         # vehicle mass (kg)
Ix       = 0.013912      # moment of inertia, roll  axis (kg·m²)
Iy       = 0.013912      # moment of inertia, pitch axis (kg·m²)
Iz       = 0.027823      # moment of inertia, yaw   axis (kg·m²)
I_TENSOR = np.diag([Ix, Iy, Iz])   # diagonal tensor: body axes taken as principal axes
I_INV    = np.linalg.inv(I_TENSOR) # inverted once here instead of inside every dynamics() call

K_THRUST = 1.36e-5       # thrust coefficient  k  (N·s²/rad²), F_i = k·ω_i²
K_DRAG   = 5.00e-7       # drag-torque coefficient b  (N·m·s²/rad²), yaw torque per ω²
ARM      = 0.175         # rotor arm length a (m), centre of mass → rotor hub

# Motor speed limits.  The allocation step clamps each ω² into
# [OMEGA_MIN², OMEGA_MAX²].  With the values above, hover needs
# ω = sqrt(MASS·G / (4·K_THRUST)) ≈ 500 rad/s per motor, well inside the range.
OMEGA_MAX = 900.0        # rad/s  (~8600 RPM)
OMEGA_MIN =  50.0        # rad/s  idle

# ─── Allocation matrix ─────────────────────────────────────
# Maps motor squared-speeds → [T, τφ, τθ, τψ]:
#
#  [T  ]   [k,    k,   k,    k  ] [ω1²]
#  [τφ ] = [ak,   0,  -ak,   0  ] [ω2²]
#  [τθ ]   [0,   ak,   0,   -ak ] [ω3²]
#  [τψ ]   [b,  -b,    b,   -b  ] [ω4²]
#
# Each row reproduces the thrust/torque formula evaluated in dynamics().
# If one is changed the other must be changed too, otherwise the motor
# commands computed from ALLOC_INV no longer produce the requested wrench.
ALLOC = np.array([
    [K_THRUST,       K_THRUST,       K_THRUST,       K_THRUST      ],
    [ARM*K_THRUST,   0,             -ARM*K_THRUST,    0             ],
    [0,              ARM*K_THRUST,   0,              -ARM*K_THRUST  ],
    [K_DRAG,        -K_DRAG,         K_DRAG,         -K_DRAG        ],
])
# Square and non-singular as long as k, a·k and b are all non-zero, so an
# exact inverse exists.  It maps a desired wrench back to per-motor ω².
ALLOC_INV = np.linalg.inv(ALLOC)

# ═══════════════════════════════════════════════════════════
# PD CONTROLLER GAINS
# ═══════════════════════════════════════════════════════════

# How controller() uses these gains:
#   * altitude / horizontal gains act on position error (m) and velocity (m/s);
#   * attitude / yaw gains multiply angle error (rad) and body rate (rad/s),
#     and the result is used directly as a body torque (N·m).  They are not
#     scaled by the inertia.
KP_Z    = 12.0;  KD_Z    = 8.0    # altitude (result multiplied by MASS → thrust)
KP_XY   =  1.8;  KD_XY   = 2.5    # horizontal position (result divided by G → tilt)
KP_ATT  = 12.0;  KD_ATT  = 5.0    # roll & pitch attitude
KP_YAW  =  6.0;  KD_YAW  = 2.0    # yaw
MAX_TILT = 0.35                    # max desired roll/pitch (rad, ~20°)

# ═══════════════════════════════════════════════════════════
# KINEMATICS HELPERS
# ═══════════════════════════════════════════════════════════

def rotation_matrix(phi, theta, psi):
    """
    Build the body → inertial rotation matrix from ZYX Euler angles.

    The result equals Rz(ψ)·Ry(θ)·Rx(φ), written out element by element.
    Each column is one body axis expressed in inertial coordinates; the third
    column is therefore the direction of the rotor thrust.

    Parameters
    ----------
    phi, theta, psi : float   roll, pitch, yaw (rad)

    Returns
    -------
    ndarray(3, 3)
    """
    c_roll, s_roll = np.cos(phi), np.sin(phi)
    c_pitch, s_pitch = np.cos(theta), np.sin(theta)
    c_yaw, s_yaw = np.cos(psi), np.sin(psi)

    # cos/sin(yaw)·sin(pitch) appear in both of the first two rows
    cyaw_spitch = c_yaw * s_pitch
    syaw_spitch = s_yaw * s_pitch

    first_row = [
        c_yaw * c_pitch,
        cyaw_spitch * s_roll - s_yaw * c_roll,
        cyaw_spitch * c_roll + s_yaw * s_roll,
    ]
    second_row = [
        s_yaw * c_pitch,
        syaw_spitch * s_roll + c_yaw * c_roll,
        syaw_spitch * c_roll - c_yaw * s_roll,
    ]
    third_row = [
        -s_pitch,
        c_pitch * s_roll,
        c_pitch * c_roll,
    ]
    return np.array([first_row, second_row, third_row])


def euler_rates(phi, theta, p, q, r):
    """
    Body angular rates (p, q, r) → Euler angle rates (φ̇, θ̇, ψ̇).

    From the kinematic relationship:
        [φ̇]   [1  sin(φ)tan(θ)   cos(φ)tan(θ)] [p]
        [θ̇] = [0  cos(φ)        -sin(φ)       ] [q]
        [ψ̇]   [0  sin(φ)/cos(θ)  cos(φ)/cos(θ)] [r]

    The matrix is singular at θ = ±90° (gimbal lock).  Near that point
    |cos θ| is clamped to 1e-6 *keeping its sign*, and tan θ is recomputed
    from the clamped cosine.  This keeps φ̇ and ψ̇ finite, mutually
    consistent, and pointing in the correct direction on either side of the
    singularity.

    Parameters
    ----------
    phi, theta : float   current roll and pitch (rad)
    p, q, r    : float   body-frame angular rates (rad/s)

    Returns
    -------
    (phi_dot, theta_dot, psi_dot) : tuple of floats (rad/s)
    """
    cp, sp = np.cos(phi), np.sin(phi)
    ct = np.cos(theta)
    if abs(ct) < 1e-6:
        # Preserve the sign: just past ±90° cos θ is negative, and forcing
        # it positive would reverse the direction of ψ̇.
        ct = np.copysign(1e-6, ct)
        # Derive tan θ from the clamped cosine so φ̇ is bounded like ψ̇.
        tt = np.sin(theta) / ct
    else:
        tt = np.tan(theta)

    coupling  = q*sp + r*cp          # shared by the φ̇ and ψ̇ rows
    phi_dot   = p + coupling * tt
    theta_dot = q*cp - r*sp
    psi_dot   = coupling / ct
    return phi_dot, theta_dot, psi_dot


# ═══════════════════════════════════════════════════════════
# MOTOR ALLOCATION
# ═══════════════════════════════════════════════════════════

def motor_speeds_from_wrench(T, tau_phi, tau_theta, tau_psi):
    """
    Solve the allocation problem: desired wrench → per-motor speeds.

    Parameters
    ----------
    T         : float   collective thrust (N)
    tau_phi   : float   roll torque (N·m)
    tau_theta : float   pitch torque (N·m)
    tau_psi   : float   yaw torque (N·m)

    Returns
    -------
    ndarray(4,)   motor speeds ω1…ω4 (rad/s)

    Each squared speed is clamped to [OMEGA_MIN², OMEGA_MAX²] on its own
    before the square root.  Under saturation the wrench actually produced
    therefore differs from the one requested.
    """
    wrench = np.array([T, tau_phi, tau_theta, tau_psi])
    squared_speeds = np.clip(ALLOC_INV @ wrench, OMEGA_MIN**2, OMEGA_MAX**2)
    return np.sqrt(squared_speeds)


# ═══════════════════════════════════════════════════════════
# DYNAMICS
# ═══════════════════════════════════════════════════════════

def dynamics(state, omegas):
    """
    Evaluate the full nonlinear 6-DOF quadcopter equations of motion.

    Parameters
    ----------
    state  : ndarray(12,)  [x, y, z, φ, θ, ψ, vx, vy, vz, p, q, r]
    omegas : ndarray(4,)   motor angular speeds ω1…ω4 (rad/s)

    Returns
    -------
    dstate : ndarray(12,)  time derivative of ``state``, same ordering
    """
    (_, _, _,
     phi, theta, psi,
     vx, vy, vz,
     p, q, r) = state
    sq1, sq2, sq3, sq4 = (w**2 for w in omegas)

    # Collective thrust along body +z: Σ k·ω_i²
    thrust = K_THRUST * (sq1 + sq2 + sq3 + sq4)

    # Body torques: roll/pitch from differential thrust on the arms,
    # yaw from the rotors' reaction drag torque.
    lever = ARM * K_THRUST
    torque = np.array([
        lever * (sq1 - sq3),
        lever * (sq2 - sq4),
        K_DRAG * (sq1 - sq2 + sq3 - sq4),
    ])

    # Translational: rotate body thrust into the inertial frame, divide by
    # mass, then add gravity on the vertical axis.
    thrust_world = rotation_matrix(phi, theta, psi) @ np.array([0.0, 0.0, thrust])
    accel = thrust_world / MASS
    accel[2] = -G + accel[2]

    # Rotational (Euler's equation): I·ω̇ = τ − ω × (I·ω); the cross
    # product couples the axes gyroscopically.
    rates = np.array([p, q, r])
    rates_dot = I_INV @ (torque - np.cross(rates, I_TENSOR @ rates))

    return np.array([
        vx, vy, vz,
        *euler_rates(phi, theta, p, q, r),
        *accel,
        *rates_dot,
    ])


# ═══════════════════════════════════════════════════════════
# RK4 INTEGRATOR
# ═══════════════════════════════════════════════════════════

def rk4_step(state, omegas, dt):
    """
    Advance ``state`` by ``dt`` using classical 4th-order Runge-Kutta.

    The motor speeds are held constant for the whole step (zero-order hold).
    """
    def deriv(s):
        return dynamics(s, omegas)

    half_dt = 0.5 * dt
    k1 = deriv(state)
    k2 = deriv(state + half_dt * k1)
    k3 = deriv(state + half_dt * k2)
    k4 = deriv(state + dt * k3)
    return state + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)


# ═══════════════════════════════════════════════════════════
# PD CONTROLLER
# ═══════════════════════════════════════════════════════════

def controller(state, setpoint, noise_std=0.0):
    """
    Cascaded PD controller: state × setpoint → [T, τφ, τθ, τψ].

    Outer loop: position error → desired tilt angles + thrust.
    Inner loop: attitude error → body torques.

    Parameters
    ----------
    state     : ndarray(12,)  current vehicle state
    setpoint  : dict          {'x', 'y', 'z', 'psi'}  desired values; missing
                              keys default to x = y = psi = 0 and z = 1
    noise_std : float         std-dev of Gaussian sensor noise (m / rad)

    Returns
    -------
    T, tau_phi, tau_theta, tau_psi : floats
        Collective thrust (N, clipped to [0, 4·k·ω_max²]) and body torques.
    """
    x, y, z, phi, theta, psi, vx, vy, vz, p, q, r = state

    # Optional sensor noise (simulates IMU / GPS imperfections).  Only the
    # controller's local copies are perturbed; ``state`` itself is untouched.
    if noise_std > 0.0:
        x     += np.random.normal(0, noise_std)
        y     += np.random.normal(0, noise_std)
        z     += np.random.normal(0, noise_std * 0.5)
        phi   += np.random.normal(0, noise_std * 0.08)
        theta += np.random.normal(0, noise_std * 0.08)
        psi   += np.random.normal(0, noise_std * 0.05)

    xd   = setpoint.get('x',   0.0)
    yd   = setpoint.get('y',   0.0)
    zd   = setpoint.get('z',   1.0)
    psid = setpoint.get('psi', 0.0)

    # ── Altitude PD ──────────────────────────────────────────
    # The PD law gives the desired *vertical* acceleration.  Only the
    # cos φ·cos θ fraction of a tilted thrust vector points up, so divide by
    # it to keep altitude while leaning to translate.  The floor of 0.5 keeps
    # the command bounded at large tilts.
    e_z      = zd - z
    tilt_cos = max(np.cos(phi) * np.cos(theta), 0.5)
    T_des    = MASS * (G + KP_Z * e_z - KD_Z * vz) / tilt_cos
    T_max    = 4.0 * K_THRUST * OMEGA_MAX**2
    T_des    = float(np.clip(T_des, 0.0, T_max))

    # ── Position → desired tilt ──────────────────────────────
    # Express BOTH the position error and the velocity in the heading
    # (yaw-only) frame so the P and D terms act along the same axes.  Mixing
    # heading-frame error with world-frame velocity damps the wrong axis
    # whenever ψ ≠ 0.
    cy, sy   = np.cos(psi), np.sin(psi)
    ex_world = xd - x
    ey_world = yd - y
    ex_body  =  cy * ex_world + sy * ey_world
    ey_body  = -sy * ex_world + cy * ey_world
    vx_body  =  cy * vx + sy * vy
    vy_body  = -sy * vx + cy * vy

    # Small-angle: heading-frame acceleration ≈ (g·θ, −g·φ)
    theta_des = float(np.clip( (KP_XY * ex_body - KD_XY * vx_body) / G, -MAX_TILT, MAX_TILT))
    phi_des   = float(np.clip(-(KP_XY * ey_body - KD_XY * vy_body) / G, -MAX_TILT, MAX_TILT))

    # ── Attitude PD ───────────────────────────────────────────
    # Yaw error wrapped to [−π, π) so a setpoint that differs from the
    # current heading by a multiple of 2π does not command extra full turns.
    e_psi = (psid - psi + np.pi) % (2.0 * np.pi) - np.pi

    tau_phi   = KP_ATT * (phi_des   - phi)   - KD_ATT * p
    tau_theta = KP_ATT * (theta_des - theta) - KD_ATT * q
    tau_psi   = KP_YAW * e_psi               - KD_YAW * r

    return T_des, tau_phi, tau_theta, tau_psi


# ═══════════════════════════════════════════════════════════
# SIMULATION LOOP
# ═══════════════════════════════════════════════════════════

def run_simulation(t_end, dt, setpoint_fn, state0=None, noise_std=0.005, label=''):
    """
    Main simulation loop.

    Parameters
    ----------
    t_end       : float       total simulation time (s)
    dt          : float       RK4 timestep (s)
    setpoint_fn : callable    f(t) → dict{'x','y','z','psi'}
    state0      : array-like(12)  initial state  (default: rest at origin)
    noise_std   : float       sensor noise std-dev (m)
    label       : str         display name for progress output

    Returns
    -------
    t_hist     : ndarray(N,)      sample times, t_hist[i] = i·dt
    state_hist : ndarray(N, 12)   state_hist[i] is the state AT t_hist[i],
                                  i.e. the one the controller acted on
                                  (so state_hist[0] == state0)

    N = floor(t_end / dt), or 0 if that is negative.
    """
    if state0 is None:
        state0 = np.zeros(12)

    # Small tolerance so a ratio like 0.3 / 0.1 = 2.999… still gives 3 steps.
    steps = max(int(np.floor(t_end / dt + 1e-9)), 0)
    # Derive timestamps from dt itself so they match the integrator even
    # when t_end is not an exact multiple of dt.
    t_hist     = np.arange(steps) * dt
    state_hist = np.empty((steps, 12))

    state = np.array(state0, dtype=float)   # private float copy of the caller's state
    print(f"  [{label}]  {t_end}s  dt={dt}s  (~{steps} steps) …")

    for i, t in enumerate(t_hist):
        # Record the state at time t BEFORE stepping, so each stored row
        # matches its timestamp.
        state_hist[i] = state

        sp = setpoint_fn(t)

        # PD controller → desired wrench
        T, tau_phi, tau_theta, tau_psi = controller(state, sp, noise_std)

        # Wrench → motor speeds
        omegas = motor_speeds_from_wrench(T, tau_phi, tau_theta, tau_psi)

        # RK4 integration
        state = rk4_step(state, omegas, dt)

        # Ground constraint: drone cannot go below z = 0
        if state[2] < 0.0:
            state[2] = 0.0
            if state[8] < 0.0:
                state[8] = 0.0

    # ``state`` is the post-final-step state (also valid when steps == 0).
    print(f"         → done.  final z = {state[2]:.3f} m")
    return t_hist, state_hist


# ═══════════════════════════════════════════════════════════
# SETPOINT PROFILES
# ═══════════════════════════════════════════════════════════

def sp_hover(t):
    """Return a fixed setpoint: hold over the origin at 1 m, heading ψ = 0 (``t`` is ignored)."""
    return dict(x=0.0, y=0.0, z=1.0, psi=0.0)


def sp_circle(t, radius=2.0, speed=0.5):
    """
    Counter-clockwise circular trajectory in the XY plane at z = 1.5 m.

    For the first 4 s the setpoint is the circle's start point (radius, 0)
    at 1.5 m, which gives the vehicle time to climb and move there.  After
    that the point travels around the circle at ``speed`` m/s.  The desired
    yaw equals the tangent heading (angle + π/2).

    Parameters
    ----------
    t      : float   simulation time (s)
    radius : float   circle radius (m)
    speed  : float   tangential speed along the circle (m/s)
    """
    phase = speed / radius * max(t - 4.0, 0.0)
    px, py = radius * np.cos(phase), radius * np.sin(phase)
    return {'x': px, 'y': py, 'z': 1.5, 'psi': phase + np.pi/2}


def sp_lshaped(t):
    """
    L-shaped waypoint mission (time-scheduled setpoints):
      t < 5 s        : take off to 1.5 m above the origin
      5 ≤ t < 22 s   : fly to (3, 0) m, heading 0
      22 ≤ t < 38 s  : fly to (3, 3) m, heading π/2
      t ≥ 38 s       : land at (3, 3) m
    """
    # (switch time, (x, y, z, psi)) — the first entry whose switch time is
    # still in the future is active.
    schedule = (
        (5.0,  (0.0, 0.0, 1.5, 0.0)),
        (22.0, (3.0, 0.0, 1.5, 0.0)),
        (38.0, (3.0, 3.0, 1.5, np.pi/2)),
    )
    landing = (3.0, 3.0, 0.0, np.pi/2)
    waypoint = next((wp for t_switch, wp in schedule if t < t_switch), landing)
    return dict(zip(('x', 'y', 'z', 'psi'), waypoint))


# ═══════════════════════════════════════════════════════════
# PLOTTING
# ═══════════════════════════════════════════════════════════

# Light theme palette (matches portfolio site)
# Named hex colours looked up by key in _style_axes() and plot_results():
# 'bg' is the figure background, 'panel' the axes background, 'border' is
# used for spines/grid/legend edges and 'text' for labels; the remaining
# keys colour the individual data series and markers.
C = {
    'bg':      '#f6f8fa',
    'panel':   '#ffffff',
    'border':  '#d0d7de',
    'blue':    '#0969da',
    'green':   '#1a7f37',
    'orange':  '#bc4c00',
    'red':     '#cf222e',
    'purple':  '#7c3aed',
    'pink':    '#db2777',   # NOTE(review): not referenced by the plotting code visible here
    'text':    '#1f2328',
    'muted':   '#636c76',   # NOTE(review): not referenced by the plotting code visible here
}


def _style_axes(fig, axes):
    """
    Paint ``fig`` and every Axes in ``axes`` with the light palette ``C``.

    Works for both 2-D and 3-D Axes; z-axis styling is applied only when the
    Axes has a ``zaxis`` attribute.
    """
    fig.patch.set_facecolor(C['bg'])
    text_colour = C['text']
    for axis in axes:
        axis.set_facecolor(C['panel'])
        for spine in getattr(axis, 'spines', {}).values():
            spine.set_edgecolor(C['border'])

        axis.tick_params(colors=text_colour, labelsize=8)
        axis.xaxis.label.set_color(text_colour)
        axis.yaxis.label.set_color(text_colour)

        # Only 3-D axes carry a z axis
        if hasattr(axis, 'zaxis'):
            axis.zaxis.label.set_color(text_colour)
            axis.zaxis.set_tick_params(labelcolor=text_colour, labelsize=8)

        axis.title.set_color(text_colour)
        axis.grid(True, color=C['border'], linewidth=0.6, linestyle='--')


def _plot_triplet_panel(ax, t, series, labels, ylabel, title):
    """Draw three time series on ``ax`` (blue/green/orange), then label, title and add a legend."""
    for values, label, colour in zip(series, labels, ('blue', 'green', 'orange')):
        ax.plot(t, values, color=C[colour], lw=1.8, label=label)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(fontsize=8, facecolor=C['panel'], edgecolor=C['border'])


def plot_results(t, states, title='Quadcopter Simulation'):
    """
    Build the six-panel summary figure for one simulation run.

    Panels (2 × 3 grid, row-major):
      1. Position x, y, z vs time
      2. Roll, pitch, yaw (degrees) vs time
      3. Linear velocity vs time
      4. 3-D trajectory, coloured by time
      5. Top-down (XY) path with start / end markers
      6. Total speed vs time

    Parameters
    ----------
    t      : ndarray(N,)      sample times (s)
    states : ndarray(N, 12)   rows [x,y,z, φ,θ,ψ, vx,vy,vz, p,q,r]
    title  : str              figure super-title

    Returns
    -------
    fig : matplotlib Figure
    """
    x, y, z = states[:, 0], states[:, 1], states[:, 2]
    roll, pitch, yaw = (np.degrees(states[:, col]) for col in (3, 4, 5))
    vx, vy, vz = states[:, 6], states[:, 7], states[:, 8]
    speed = np.sqrt(vx**2 + vy**2 + vz**2)

    fig = plt.figure(figsize=(16, 10))
    fig.suptitle(title, fontsize=13, fontweight='bold')

    ax_pos = fig.add_subplot(2, 3, 1)
    _plot_triplet_panel(ax_pos, t, (x, y, z), ('x', 'y', 'z'),
                        'Position (m)', 'Position vs Time')

    ax_att = fig.add_subplot(2, 3, 2)
    _plot_triplet_panel(ax_att, t, (roll, pitch, yaw),
                        ('φ roll', 'θ pitch', 'ψ yaw'),
                        'Angle (°)', 'Euler Angles vs Time')

    ax_vel = fig.add_subplot(2, 3, 3)
    _plot_triplet_panel(ax_vel, t, (vx, vy, vz), ('vx', 'vy', 'vz'),
                        'Velocity (m/s)', 'Linear Velocity vs Time')

    # 3-D trajectory with a time colour bar
    ax_3d = fig.add_subplot(2, 3, 4, projection='3d')
    points = ax_3d.scatter(x, y, z, c=t, cmap='Blues_r', s=1.5, linewidths=0)
    ax_3d.set_xlabel('X (m)')
    ax_3d.set_ylabel('Y (m)')
    ax_3d.set_zlabel('Z (m)')
    ax_3d.set_title('3-D Trajectory')
    cbar = fig.colorbar(points, ax=ax_3d, pad=0.12, shrink=0.7)
    cbar.set_label('Time (s)', color=C['text'], fontsize=8)
    cbar.ax.yaxis.set_tick_params(color=C['text'], labelsize=7)
    plt.setp(cbar.ax.yaxis.get_ticklabels(), color=C['text'])

    # Top-down path
    ax_xy = fig.add_subplot(2, 3, 5)
    ax_xy.plot(x, y, color=C['blue'], lw=1.8)
    ax_xy.plot(x[0], y[0], 'o', color=C['green'], ms=7, label='start')
    ax_xy.plot(x[-1], y[-1], 'x', color=C['red'], ms=7, label='end')
    ax_xy.set_xlabel('X (m)')
    ax_xy.set_ylabel('Y (m)')
    ax_xy.set_title('Top-Down Path')
    ax_xy.set_aspect('equal', adjustable='datalim')
    ax_xy.legend(fontsize=8, facecolor=C['panel'], edgecolor=C['border'])

    # Total speed
    ax_speed = fig.add_subplot(2, 3, 6)
    ax_speed.plot(t, speed, color=C['purple'], lw=1.8)
    ax_speed.set_xlabel('Time (s)')
    ax_speed.set_ylabel('Speed (m/s)')
    ax_speed.set_title('Total Speed vs Time')

    _style_axes(fig, [ax_pos, ax_att, ax_vel, ax_3d, ax_xy, ax_speed])
    plt.tight_layout()
    return fig


# ═══════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════

if __name__ == '__main__':
    DT    = 0.01     # RK4 timestep (s)
    T_END = 120.0    # simulation duration (s)

    # Header: box banner, vehicle parameters, hover speed sqrt(m·g / 4k)
    banner = (
        "",
        "╔══════════════════════════════════════╗",
        "║  Quadcopter 6-DOF Flight Simulator   ║",
        "╚══════════════════════════════════════╝",
        f"  Mass  = {MASS} kg     Arm = {ARM} m",
        f"  Ix,Iy = {Ix} kg·m²   Iz = {Iz} kg·m²",
        f"  Hover speed ≈ {np.sqrt(MASS*G/(4*K_THRUST)):.1f} rad/s per motor",
        "",
    )
    for banner_line in banner:
        print(banner_line)

    # (duration, setpoint profile, progress label, figure title); run in order
    scenarios = (
        (T_END, sp_hover, 'Hover at 1 m',
         'Simulation 1 — Hover at 1 m'),
        (T_END, sp_circle, 'Circular trajectory (r=2 m, v=0.5 m/s)',
         'Simulation 2 — Circular Trajectory (r=2 m)'),
        (50.0, sp_lshaped, 'L-shaped waypoint mission',
         'Simulation 3 — L-Shaped Waypoint Mission'),
    )

    figures = []
    for duration, profile, run_label, fig_title in scenarios:
        t_run, s_run = run_simulation(
            duration, DT, profile, noise_std=0.005, label=run_label
        )
        figures.append(plot_results(t_run, s_run, fig_title))

    print()
    print("All simulations complete — displaying plots.")
    plt.show()
