import math
from collections import namedtuple
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ubitstr
import defs
from defs import Signal, ModSymbol
import numpy as np
from numpy import fft


# Named TUM palette colours as RGB tuples with components in [0, 1].
TUMBlue       = (0.00, 0.40, 0.74)
TUMRed        = (0.77, 0.03, 0.11)
TUMOrange     = (0.89, 0.45, 0.13)
TUMGreen      = (0.64, 0.68, 0.00)
TUMDarkGreen  = (0.00, 0.49, 0.19)
TUMDarkYellow = (0.98, 0.73, 0.00)

# Alpha channel used for the area fills underneath plotted signals.
FILL_ALPHA = 0.25

# Green -> yellow -> red colormap used to colour received symbols by their
# distance to the nearest constellation point in demodulation().
_CMAP_DIST = mcolors.LinearSegmentedColormap.from_list(
    'dist', [TUMDarkGreen, TUMDarkYellow, TUMRed]
)


################################################################################
# bitstr

def bitstr(bitstr: str, xtick_interval: int = None, k: int = None):
    """
    Plots a bitstring as a step signal with 0 and 1 amplitudes.
    :param bitstr: the bitstring to plot.
    :param xtick_interval: interval of xticks to group bits (e.g. the codeword length).
    :param k: within each group of bits, every bit at 0-based position >= k (i.e. all bits
        after the first k) is colored red; useful for channel coding, where k is the number
        of data bits per codeword. The group length is xtick_interval, or k + 1 (a single
        redundant bit per group) when no xtick_interval is given.
    """
    bits = ubitstr.to_bits(bitstr)
    amplitudes = [1.0 if b else 0.3 for b in bits]

    # Length of the bit groups used for the red highlighting (None: no highlighting).
    # Using xtick_interval keeps the colouring consistent with the drawn codeword ticks.
    if k is None:
        group_len = None
    elif xtick_interval is not None:
        group_len = xtick_interval
    else:
        group_len = k + 1

    _, ax = plt.subplots(figsize=(max(8, len(bits) * 0.4), 3))

    for i, amp in enumerate(amplitudes):
        # every bit after the first k of a group is special/parity -> red
        is_redundant = group_len is not None and (i % group_len) >= k
        c = TUMRed if is_redundant else TUMBlue

        ax.plot([i, i + 1], [amp, amp], color=c, linewidth=2)
        ax.fill_between([i, i + 1], 0, amp, color=(*c, FILL_ALPHA))

        # vertical connector at transition from previous bit
        if i > 0 and amplitudes[i] != amplitudes[i - 1]:
            ax.plot([i, i], [amplitudes[i - 1], amp], color=c, linewidth=2)

    ax.set_xlim(0, len(bits))
    ax.set_ylim(0, 1.3)
    ax.set_xlabel("Bit index")
    ax.set_yticks([0.3, 1.0])
    ax.set_yticklabels(["0", "1"])
    if xtick_interval is not None:
        ax.set_xticks(range(xtick_interval, len(bits) + 1, xtick_interval))
    plt.tight_layout()
    plt.show()


################################################################################
# impulse

def _draw_segments(ax, segments: list[tuple], color, pad=0.25):
    """Draw a step function onto ``ax`` from (t_start, t_end, amplitude) segments.

    Dashed zero-level tails of length ``pad`` are drawn before the first and after
    the last segment. A vertical edge is drawn wherever the amplitude changes,
    including the rise from 0 at the start and the fall back to 0 at the end.
    ``segments`` must be non-empty.
    """
    line_kw = dict(color=color, linewidth=2)
    start, stop = segments[0][0], segments[-1][1]

    # dashed lead-in / lead-out at zero level
    for tail_x in ([start - pad, start], [stop, stop + pad]):
        ax.plot(tail_x, [0, 0], '--', **line_kw)

    level = 0
    for seg_start, seg_end, seg_amp in segments:
        if seg_amp != level:
            ax.plot([seg_start, seg_start], [level, seg_amp], **line_kw)
        ax.plot([seg_start, seg_end], [seg_amp, seg_amp], **line_kw)
        level = seg_amp

    # close the step back down to zero after the last segment
    if level != 0:
        ax.plot([stop, stop], [level, 0], **line_kw)


def impulse(impulse_def, ax=None):
    """
    Plot the shape of a base impulse.
    :param impulse_def: name of an impulse to look up, or the constant segments themselves
        (only piecewise-constant impulses, no cos2).
    :param ax: axes to draw into when the plot is part of an external figure; if None, a new
        figure is created and shown.
    """
    name, segments = defs.impulse_by_name(impulse_def)
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots()

    _draw_segments(ax, segments, TUMBlue, pad=0.25)

    # the x axis is in units of the symbol period T
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xticks([-0.5, 0, 0.5])
    ax.set_xticklabels(['-T/2', '0', 'T/2'])
    ax.set_yticks([-1, 0, 1])
    ax.axhline(0, color='black', linewidth=0.5)
    ax.grid(True, linewidth=0.5, linestyle='--', alpha=0.5)
    ax.set_aspect(1)
    if name:
        ax.set_title(name)

    if not owns_figure:
        return
    plt.tight_layout()
    plt.show()


################################################################################
# linecode

def _bitstr_to_levels(bitstr: str, linecode: str) -> list[tuple]:
    """Return a list of level segments (t_start, t_end, amplitude) for the encoded signal.

    Time is measured in bit periods: bit i occupies [i, i + 1). Amplitudes are -1, 0 or 1.
    :param bitstr: the bitstring to encode.
    :param linecode: one of 'NRZ', 'RZ', 'Manchester', 'MLT-3'.
    :raises ValueError: if linecode is not one of the supported names.
    """
    bits = ubitstr.to_bits(bitstr)
    segments = []

    if linecode == "NRZ":
        for t, b in enumerate(bits):
            segments.append((t, t + 1, 1 if b else -1))

    elif linecode == "RZ":
        # first half carries the bit (+1 / -1), second half returns to zero
        for t, b in enumerate(bits):
            segments.append((t,       t + 0.5, 1 if b else -1))
            segments.append((t + 0.5, t + 1,   0))

    elif linecode == "Manchester":
        for t, b in enumerate(bits):
            # 1 bit: rising edge
            # 0 bit: falling edge
            segments.append((t,       t + 0.5, -1 if b else  1))
            segments.append((t + 0.5, t + 1,    1 if b else -1))

    elif linecode == "MLT-3":
        # cycle (0,1,0,-1)
        # advance on 1 bit
        cycle = [0, 1, 0, -1]
        state = 0
        for t, b in enumerate(bits):
            if b:
                state = (state + 1) % 4
            segments.append((t, t + 1, cycle[state]))

    else:
        # Fail loudly: an empty segment list would only surface later as an obscure
        # IndexError (or an empty plot) in the callers.
        raise ValueError(
            f"unknown linecode {linecode!r}; expected 'NRZ', 'RZ', 'Manchester' or 'MLT-3'"
        )

    return segments


def linecode(bitstr: str, linecode: str, invert: bool = False, xtick_interval: int = 8, bit_labels: bool = True):
    """
    Plot a bitstring encoded with a specific linecode.
    :param bitstr: a bitstr.
    :param linecode: str name of the linecode ('NRZ', 'RZ', 'Manchester' or 'MLT-3').
    :param invert: inverts the signal amplitudes (the plot is then drawn in green).
    :param xtick_interval: distance of labelled xticks; should be a codeword length for highlighting.
    :param bit_labels: show bitlabels at each impulse if true
    :return: the plotted signal (including the inversion) sampled with 300 samples per bit,
        e.g. as input for spectrum() or filtered().
    """
    bits = ubitstr.to_bits(bitstr)
    segments = _bitstr_to_levels(bitstr, linecode)
    n_bits = len(bits)
    sign = -1 if invert else 1
    color = TUMGreen if invert else TUMBlue

    _, ax = plt.subplots(figsize=(max(8, n_bits * 0.4), 3))

    prev_amp = None
    for t_start, t_end, level in segments:
        amp = sign * level
        if prev_amp is not None and prev_amp != amp:
            ax.plot([t_start, t_start], [prev_amp, amp], color=color, linewidth=2)
        ax.plot([t_start, t_end], [amp, amp], color=color, linewidth=2)
        ax.fill_between([t_start, t_end], 0, amp, color=(*color, FILL_ALPHA))
        prev_amp = amp

    # bit labels
    if bit_labels:
        for i, b in enumerate(bits):
            ax.text(i + 0.5, -1.2, '1' if b else '0',
                    ha='center', va='center', fontsize=9, color='black')

    ax.set_xlim(0, n_bits)
    ax.set_ylim(-1.4, 1.1)
    all_ticks = list(range(0, n_bits + 1))
    ax.set_xticks(all_ticks)
    ax.set_xticklabels([
        str(x) if x % xtick_interval == 0 else ''
        for x in all_ticks
    ])
    ax.axhline(0, color='black', linewidth=0.5)
    ax.xaxis.grid(True, linewidth=0.5, linestyle='--', alpha=0.5)
    ax.set_axisbelow(True)
    ax.set_title(linecode)
    plt.tight_layout()
    plt.show()

    # Return the signal for fft: 300 samples per bit, each sample taking the amplitude
    # of the last segment that starts at or before it.
    t = np.linspace(0, n_bits, n_bits * 300, endpoint=False)
    if not segments:
        return Signal(t, np.zeros_like(t), 1)
    t_starts = np.array([seg[0] for seg in segments])
    # apply the same inversion as the plot so the returned signal matches what is shown
    levels = sign * np.array([seg[2] for seg in segments])
    idx = np.clip(np.searchsorted(t_starts, t, side='right') - 1, 0, len(segments) - 1)
    return Signal(t, levels[idx], 1)


def linecode_drift(bitstr: str, linecode: str, drift: float = 0.0,
                   sampling: bool = False, sampling_offset: float = 0.5,
                   repair_clock: bool = False,
                   invert: bool = False, xtick_interval: int = 8):
    """
    Like linecode() but with clock drift between sender and receiver.

    The x axis is receiver time in receiver bit periods; the sender's segment times are
    multiplied by (1 + drift) to obtain receiver time.
    :param bitstr: a bitstr.
    :param linecode: str name of the linecode (as accepted by linecode()).
    :param drift: relative drift of the receiver compared to the sender (0 means no drift).
        Positive drift stretches the signal (each symbol lasts 1 + drift receiver periods),
        negative drift compresses it.
        NOTE(review): an earlier version of this docstring claimed the opposite for positive
        drift ("signal is finished faster"); the code stretches -- confirm the intended sign.
    :param sampling: whether to sample the signal once per receiver bit period (shown with red dots)
    :param sampling_offset: offset of the sampling point within one receiver bit period
        (0 at the start of the period, 1 at the end)
    :param repair_clock: re-anchor every bit at its receiver tick so the drift does not
        accumulate; only the timing inside each bit is scaled. Parts of a bit extending
        beyond the next receiver tick are drawn in gray.
    :param invert: negate the amplitudes (the plot colour stays blue)
    :param xtick_interval: distance of labelled xticks
    :return: list of sampled amplitudes (with inversion applied); empty if sampling is False
    """
    bits = ubitstr.to_bits(bitstr)
    segments = _bitstr_to_levels(bitstr, linecode)
    n_bits = len(bits)
    # factor converting sender time into receiver time
    scale = 1.0 + drift

    if repair_clock:
        # each bit starts at its own receiver tick; only its internal timing is scaled
        drifted = []
        for t_s, t_e, amp in segments:
            bit_idx = int(t_s + 1e-9)           # which receiver tick this segment belongs to
            local_s = t_s - bit_idx             # position within the bit [0, 1)
            local_e = min(t_e - bit_idx, 1.0)   # end within the bit, capped at 1
            drifted.append((bit_idx + local_s * scale, bit_idx + local_e * scale, amp))
    else:
        # continuous chained signal: scale all segment times uniformly
        drifted = [(t_s * scale, t_e * scale, amp) for t_s, t_e, amp in segments]

    fig_w = max(8, n_bits * 0.4)
    fig, ax = plt.subplots(figsize=(fig_w, 3))

    prev_amp = None
    prev_bit = None
    for seg_idx, (t_start, t_end, amp_raw) in enumerate(drifted):
        amp = -amp_raw if invert else amp_raw
        # bit index is derived from the undrifted sender start time of the segment
        cur_bit = int(segments[seg_idx][0] + 1e-9)
        same_bit = (cur_bit == prev_bit)

        # vertical connector: always within same bit; also across bits when not repair_clock
        if prev_amp is not None and prev_amp != amp:
            if not repair_clock or same_bit:
                ax.plot([t_start, t_start], [prev_amp, amp], color=TUMBlue, linewidth=2)

        if repair_clock:
            boundary = float(cur_bit + 1)
            # Normal (non-overhanging) portion
            if t_start < boundary:
                ax.plot([t_start, min(t_end, boundary)], [amp, amp], color=TUMBlue, linewidth=2)
                ax.fill_between([t_start, min(t_end, boundary)], 0, amp,
                                color=(*TUMBlue, FILL_ALPHA))
            # overhanging portion (beyond receiver tick boundary)
            if t_end > boundary:
                overhang_s = max(t_start, boundary)
                ax.plot([overhang_s, t_end], [amp, amp], color='gray', linewidth=2)
                ax.fill_between([overhang_s, t_end], 0, amp,
                                color=(0.6, 0.6, 0.6, FILL_ALPHA))
        else:
            ax.plot([t_start, t_end], [amp, amp], color=TUMBlue, linewidth=2)
            ax.fill_between([t_start, t_end], 0, amp, color=(*TUMBlue, FILL_ALPHA))

        prev_amp = amp  # always carry forward - needed for intra-bit connectors
        prev_bit = cur_bit

    # the view is fixed to n_bits receiver periods, so a stretched signal is cut off
    ax.set_xlim(0, n_bits)
    ax.set_ylim(-1.4, 1.1)
    all_ticks = list(range(0, n_bits + 1))
    ax.set_xticks(all_ticks)
    ax.set_xticklabels([str(x) if x % xtick_interval == 0 else '' for x in all_ticks])
    ax.axhline(0, color='black', linewidth=0.5)
    ax.xaxis.grid(True, linewidth=0.5, linestyle='--', alpha=0.5)
    ax.set_axisbelow(True)
    ax.set_title(f'{linecode}  (drift={drift:+.2f}{"  repair_clock" if repair_clock else ""})')

    sampled_amps = []
    if sampling:
        # One sample per receiver bit period at i + sampling_offset. The sampled value is the
        # amplitude of the last drifted segment starting at or before the sample time; the
        # index is clipped, so samples in gaps or past the end of the signal return the
        # amplitude of the nearest preceding segment.
        # NOTE(review): searchsorted assumes the drifted start times are sorted, which no
        # longer holds with repair_clock and drift > 1 (segments overlap the next tick).
        t_starts_d = np.array([seg[0] for seg in drifted])
        amps_d = np.array([(-seg[2] if invert else seg[2]) for seg in drifted])
        for i in range(n_bits):
            t_sample = i + sampling_offset
            idx = int(np.clip(np.searchsorted(t_starts_d, t_sample, side='right') - 1,
                              0, len(drifted) - 1))
            amp_at_sample = amps_d[idx]
            sampled_amps.append(amp_at_sample)
            ax.plot(t_sample, amp_at_sample, 'o', color=TUMRed, markersize=6, zorder=5)

    plt.tight_layout()
    plt.show()

    return sampled_amps


################################################################################
# spectrum & filtering

def _compute_spectrum(sig: Signal):
    """Compute the spectrum of a uniformly sampled signal using the FFT.

    :param sig: signal as obtained from a linecode or modulation plot; only ``sig.t``
        (sample times) and ``sig.y`` (sample values) are used.
    :return: ``(freq, fy)`` -- the FFT bin frequencies in cycles per time unit (in numpy
        ``fftfreq`` order, i.e. not sorted) and the complex FFT coefficients.
    :raises ValueError: if the signal has fewer than two samples (sample rate undefined).
    """
    n_samples = len(sig.t)
    if n_samples < 2:
        raise ValueError("signal needs at least two samples to compute a spectrum")
    # N samples span N - 1 sampling intervals between the first and the last time stamp;
    # using N here would scale every frequency by N / (N - 1).
    rate = (n_samples - 1) / (sig.t[-1] - sig.t[0])
    fy = fft.fft(sig.y)
    freq = fft.fftfreq(len(fy), d=1 / rate)
    return freq, fy


def spectrum(sig: Signal, freq_filter=None, xlim: float = 10):
    """Plot the normalized magnitude spectrum of a signal
    :param sig: the signal, as obtained from a linecode or modulation plot
    :param freq_filter: optional filter function mapping a frequency to a gain; if given, the
        filter curve and the filtered spectrum are plotted as well
    :param xlim: the +- xlimit of the frequency axis
    """
    freq, fy = _compute_spectrum(sig)
    # fftfreq orders the bins as [0, +f ..., -f ...]; sort them ascending so the line plot
    # does not draw a spurious segment from the highest positive to the most negative bin.
    freq = fft.fftshift(freq)
    fy = fft.fftshift(fy)
    amp = np.abs(fy)
    norm = amp.max()
    if norm == 0:
        norm = 1.0  # all-zero signal: avoid 0/0 = NaN in the normalization

    _, ax = plt.subplots(figsize=(12, 3))
    ax.plot(freq, amp / norm, color=TUMBlue, linewidth=1.2, label='spectrum')
    if freq_filter is not None:
        filt = np.vectorize(freq_filter)(freq)
        ax.plot(freq, filt, linestyle="--", color=TUMRed, linewidth=1.2, label='filter')
        ax.plot(freq, np.abs(filt * fy) / norm, color=TUMOrange, linewidth=1.2, label='filtered')
        ax.legend(fontsize=8)
    ax.set_xlim(-xlim, xlim)
    ax.set_ylabel('Normalized Absolute Amplitude')
    ax.set_xlabel('Frequency')
    ax.grid(True, linewidth=0.5, linestyle='--', alpha=0.4)
    plt.tight_layout()
    plt.show()


def filtered(sig: Signal, freq_filter, xscale: float = 1.0, xtick_interval: int = 1):
    """Plot a signal together with its filtered version and return the filtered signal.

    The filter is applied in the frequency domain: the FFT of the signal is multiplied by
    ``freq_filter(f)`` per frequency bin and transformed back; the real part is kept.
    :param sig: the signal, as obtained from a linecode or modulation plot
    :param freq_filter: function mapping a frequency to a gain
    :param xscale: scaling factor of the figure width
    :param xtick_interval: only every xtick_interval-th integer tick is labelled
    :return: the filtered signal (same time axis and bptu as ``sig``)
    """
    freq, spec = _compute_spectrum(sig)
    gain = np.vectorize(freq_filter)(freq)
    y_filtered = np.real(fft.ifft(gain * spec))
    sig_filt = Signal(sig.t, y_filtered, sig.bptu)

    duration = sig.t[-1] - sig.t[0]
    width = max(8, duration * sig.bptu * 0.4) * xscale
    _, ax = plt.subplots(figsize=(width, 3))
    for trace, trace_color, trace_label in ((sig, TUMBlue, 'original'),
                                            (sig_filt, TUMOrange, 'filtered')):
        ax.plot(trace.t, trace.y, color=trace_color, linewidth=1.0, label=trace_label)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.grid(True, linewidth=0.5, linestyle='--', alpha=0.4)
    ax.set_ylabel('signal')
    ax.set_xlabel('t / T')
    ax.legend(fontsize=8)
    ax.set_xlim(0, sig.t[-1])

    # one tick per time unit, labelled only every xtick_interval units
    last_tick = int(math.ceil(sig.t[-1]))
    ticks = list(range(last_tick + 1))
    ax.set_xticks(ticks)
    ax.set_xticklabels(['' if tick % xtick_interval else str(tick) for tick in ticks])

    plt.tight_layout()
    plt.show()
    return sig_filt


################################################################################
# constellation

def constellation(scheme: str, ax=None):
    """Plot the constellation diagram of a modulation scheme.

    :param scheme: modulation scheme name (key of ``defs.CONSTELLATION_DEFS``).
    :param ax: axes to draw into when embedding in an external figure; if None, a new
        figure is created and shown.
    """
    symbol_points = defs.CONSTELLATION_DEFS[scheme]
    new_figure = ax is None
    if new_figure:
        _, ax = plt.subplots()

    if 'PSK' in scheme:
        # dashed unit circle as a visual reference for PSK schemes
        phis = [math.radians(deg) for deg in range(361)]
        ax.plot([math.cos(p) for p in phis], [math.sin(p) for p in phis],
                color='gray', linewidth=0.8, linestyle='--')

    for i_val, q_val, label in symbol_points:
        ax.plot(i_val, q_val, 'o', color=TUMBlue, markersize=7, zorder=3)
        ax.annotate(label, (i_val, q_val), textcoords='offset points', xytext=(5, 4),
                    fontsize=7)

    # 32-QAM gets a wider view window than the other schemes
    half_width = 3 if scheme == "32-QAM" else 2
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_xlabel('I')
    ax.set_ylabel('Q')
    ax.set_aspect('equal')
    ax.set_title(scheme)
    ax.grid(True, linewidth=0.5, linestyle='--', alpha=0.4)

    if not new_figure:
        return
    plt.tight_layout()
    plt.show()


################################################################################
# modulation

def _build_modulation_symbols(bitstr: str, scheme: str) -> list[tuple]:
    """Map a bitstring onto constellation points, returning one (I, Q) tuple per symbol.

    Characters other than '0' and '1' are ignored. The remaining bits are cut into
    chunks of the scheme's bits-per-symbol (the label length of its first point) and
    each chunk is looked up by label. Trailing bits that do not fill a whole chunk
    are dropped.
    """
    table = defs.CONSTELLATION_DEFS[scheme]
    bits_per_symbol = len(table[0][2])
    point_of = {label: (i_val, q_val) for i_val, q_val, label in table}
    clean = ''.join(ch for ch in bitstr if ch in '01')
    starts = range(0, len(clean) - bits_per_symbol + 1, bits_per_symbol)
    return [point_of[clean[start:start + bits_per_symbol]] for start in starts]


def _apply_impulse(weights: list[float], t: np.ndarray, n: int, impulse: str) -> np.ndarray:
    """Shape per-symbol weights with a base impulse into a sampled baseband signal.

    The samples are split into n slots of ``len(t) // n`` samples each; slot i is filled
    with ``weights[i]`` times the impulse shape. Samples left over when ``len(t)`` is not
    a multiple of n stay zero.
    :param weights: amplitude per symbol (one entry per slot)
    :param t: sample times; only its length is used
    :param n: number of symbol slots; n <= 0 yields an all-zero result
    :param impulse: 'cos2' for the pulse cos(pi * (tau - 0.5))**2 over the slot
        (tau in [0, 1)), anything else for a rectangular pulse
    :return: array with len(t) samples
    """
    result = np.zeros(len(t))
    if n <= 0:
        # no symbols (e.g. a bitstr shorter than one symbol): avoid division by zero
        return result
    sps = len(t) // n  # samples per symbol

    # the pulse shape is identical for every slot, so compute it once
    if impulse == 'cos2':
        tau = np.linspace(0, 1, sps, endpoint=False)
        shape = np.cos(np.pi * (tau - 0.5)) ** 2
    else:  # rect (default)
        shape = np.ones(sps)

    for i, w in enumerate(weights):
        slot = result[i * sps:(i + 1) * sps]  # view into result
        slot[:] = w * shape[:len(slot)]
    return result


def modulation(bitstr: str, scheme: str, carrier_freq: float = 2,
               xscale: float = 1.0, impulse: str = 'rect'):
    """Modulate the bitstring with the given modulation scheme and plot the result.

    The transmitted signal is I(t) * cos(2 pi f_c t) - Q(t) * sin(2 pi f_c t), where I(t)
    and Q(t) are the per-symbol constellation coordinates shaped by the base impulse.
    Time is measured in symbol periods with 300 samples per symbol.
    :param bitstr: data to modulate; trailing bits that do not fill a whole symbol are dropped
    :param scheme: modulation scheme name (key of defs.CONSTELLATION_DEFS)
    :param carrier_freq: relative frequency of the carrier wave to a symbol
    :param xscale: scaling factor for better readability
    :param impulse: base impulses to use (cos2 or rect)
    :return: the modulated signal (bptu = bits per symbol) for e.g. filtering
    :raises ValueError: if bitstr does not contain enough bits for a single symbol
    """
    symbols = _build_modulation_symbols(bitstr, scheme)
    if not symbols:
        raise ValueError(f"bitstr contains fewer bits than one {scheme} symbol")
    n = len(symbols)
    nb = len(defs.CONSTELLATION_DEFS[scheme][0][2])   # bits per symbol
    has_q = any(Q != 0 for _, Q in symbols)
    t = np.linspace(0, n, n * 300, endpoint=False)

    I_weights = [s[0] for s in symbols]
    Q_weights = [s[1] for s in symbols]
    I_bb = _apply_impulse(I_weights, t, n, impulse)
    Q_bb = _apply_impulse(Q_weights, t, n, impulse)
    cos_carrier = np.cos(2 * np.pi * carrier_freq * t)
    sin_carrier = np.sin(2 * np.pi * carrier_freq * t)
    I_mod = I_bb * cos_carrier
    Q_mod = Q_bb * sin_carrier

    # Match linecode width: each bit is 0.4 in, each symbol with nb bits per symbol is nb * 0.4 in
    fig_w = max(8, n * nb * 0.4) * xscale

    def _style(ax, ylabel):
        ax.axhline(0, color='black', linewidth=0.5)
        ax.set_xlim(0, n)
        ax.set_xticks(range(0, n + 1))
        ax.xaxis.grid(True, linewidth=0.5, linestyle='--', alpha=0.4)
        ax.set_ylabel(ylabel)

    if has_q:
        fig, axes = plt.subplots(3, 1, figsize=(fig_w, 8), sharex=True)
        axes[0].plot(t, cos_carrier, color='gray', linewidth=0.8, alpha=0.6)
        axes[0].plot(t, I_bb, color=TUMBlue, linewidth=1.5, linestyle='--')
        axes[0].plot(t, I_mod, color=TUMBlue, linewidth=1.5)
        _style(axes[0], 'I')
        axes[1].plot(t, sin_carrier, color='gray', linewidth=0.8, alpha=0.6)
        axes[1].plot(t, Q_bb, color=TUMRed, linewidth=1.5, linestyle='--')
        axes[1].plot(t, Q_mod, color=TUMRed, linewidth=1.5)
        _style(axes[1], 'Q')
        axes[2].plot(t, I_mod - Q_mod, color='black', linewidth=1.5)
        _style(axes[2], 'I - Q  (result)')
    else:
        fig, axes = plt.subplots(1, 1, figsize=(fig_w, 3))
        axes = [axes]
        axes[0].plot(t, cos_carrier, color='gray', linewidth=0.8, alpha=0.6)
        axes[0].plot(t, I_bb, color=TUMBlue, linewidth=1.5, linestyle='--')
        axes[0].plot(t, I_mod, color=TUMBlue, linewidth=1.5)
        _style(axes[0], 'I')

    axes[-1].set_xlabel('t / T')
    # braces are required so mathtext subscripts the whole word, not just the 'c'
    fig.suptitle(f'{scheme}  -  $f_{{carrier}}$ = {carrier_freq}')
    plt.tight_layout()
    plt.show()

    result = I_mod - Q_mod if has_q else I_mod
    return Signal(t, result, nb)


def demodulation(received: list[ModSymbol], scheme: str) -> str:
    """Scatter received symbols onto the constellation diagram of ``scheme`` and return their bits.

    Each received symbol is drawn as a cross coloured by its Euclidean distance to the
    nearest ideal constellation point (colour scale from 0 to 0.9; larger distances saturate).
    :param received: received symbols; each must provide ``I``, ``Q`` and ``bits``
        (characters '0'/'1'; anything other than '1' counts as 0)
    :param scheme: modulation scheme name (key of defs.CONSTELLATION_DEFS)
    :return: bit string built from the concatenated ``bits`` of all received symbols.
        NOTE(review): the bits are taken from the symbols as given, not re-decided from
        their I/Q position by a nearest-point decision -- confirm this is intended.
    """
    ideal_points = defs.CONSTELLATION_DEFS[scheme]

    _, ax = plt.subplots(figsize=(8, 8))

    if 'PSK' in scheme:
        # dashed unit circle as a visual reference for PSK schemes
        phis = [math.radians(deg) for deg in range(361)]
        ax.plot([math.cos(p) for p in phis], [math.sin(p) for p in phis],
                color='gray', linewidth=0.8, linestyle='--')

    for i_val, q_val, label in ideal_points:
        ax.plot(i_val, q_val, 'o', color=TUMBlue, markersize=7, zorder=3)
        ax.annotate(label, (i_val, q_val), textcoords='offset points', xytext=(5, 4), fontsize=7)

    ideal_xy = [(i_val, q_val) for i_val, q_val, _ in ideal_points]

    def _nearest_distance(sym):
        # distance from a received symbol to the closest ideal constellation point
        best_sq = min((sym.I - px)**2 + (sym.Q - py)**2 for px, py in ideal_xy)
        return math.sqrt(best_sq)

    distances = [_nearest_distance(sym) for sym in received]
    crosses = ax.scatter(
        [sym.I for sym in received],
        [sym.Q for sym in received],
        c=distances, cmap=_CMAP_DIST, s=80, marker='x',
        linewidths=2, zorder=4, vmin=0, vmax=0.9,
    )
    plt.colorbar(crosses, ax=ax, label='distance to nearest point', fraction=0.046, pad=0.04)

    # 32-QAM gets a wider view window than the other schemes
    half_width = 3 if scheme == "32-QAM" else 2
    ax.set_xlim(-half_width, half_width)
    ax.set_ylim(-half_width, half_width)
    ax.set_aspect('equal')
    ax.set_xlabel('I')
    ax.set_ylabel('Q')
    ax.set_title(f'{scheme} - received (x) vs constellation (o)')
    ax.grid(True, linewidth=0.5, linestyle='--', alpha=0.4)
    plt.tight_layout()
    plt.show()

    bit_flags = [ch == '1' for sym in received for ch in sym.bits]
    return ubitstr.to_bitstr(bit_flags)