
from collections import namedtuple
import math
import numpy as np
import ubitstr
PI: float = math.pi  # short alias for math.pi used by the signal helpers

################################################################################
# general & util

# Sampled waveform: t = sample times, y = sample values (demodulate_signal
# uses both as numpy arrays). NOTE(review): the meaning of `bptu` is not
# visible in this module (bits per time unit?) — confirm.
Signal = namedtuple("Signal", "t y bptu")
# One symbol in the IQ plane: in-phase component I, quadrature component Q,
# and the bit label of the matching constellation point.
ModSymbol = namedtuple("ModSymbol", "I Q bits")

def visualize_ascii(text: str):
    """Print a table listing every character of *text* with its hex and binary code.

    Columns are left-aligned in 6-character fields; the binary column is
    zero-padded to at least 8 digits.
    """
    header = f'{"Char":<6} {"Hex":<6} {"Binary"}'
    rule = f'{"-" * 4:<6} {"-" * 4:<6} {"-" * 8}'
    print(header)
    print(rule)
    for codepoint, char in ((ord(c), c) for c in text):
        hex_col = f"0x{codepoint:02X}"
        print(f'{char:<6} {hex_col:<6} {codepoint:08b}')


def freq_filter(f, cutoff):
    """Low-pass magnitude response that stays flat and then drops steeply at *cutoff*.

    With ``x = |2*f / cutoff|`` the gain is:

    * ``1.0`` for ``x <= 1``  (``|f| <= cutoff / 2``),
    * ``cos(pi/2 * (x - 1)) ** 0.25`` for ``1 < x <= 2`` — close to 1 for most of
      the band, falling steeply to ~0 as ``|f|`` approaches ``cutoff``,
    * ``0.0`` for ``x > 2``  (``|f| > cutoff``).

    Always returns a Python ``float``.
    """
    x = abs((2 * f) / cutoff)
    if x <= 1:
        return 1.0
    if x > 2:
        return 0.0
    # 0.5*pi*(x-1) lies in (0, pi/2], so the cosine is in [0, 1) and the
    # fractional power is well defined.
    return math.cos(0.5 * math.pi * (x - 1)) ** 0.25


def lowpass(f, alpha=1.2, cutoff=1):
    """Rough model of a copper-wire low-pass: unity gain up to *cutoff*,
    then ``exp(-alpha * sqrt(|f| - cutoff))`` above it."""
    excess = abs(f) - cutoff
    if excess < 0:
        # inside the passband: sqrt(0) -> gain exp(0) == 1
        excess = 0
    return math.exp(-alpha * math.sqrt(excess))


################################################################################
# impulse defs

# Impulse shapes as piecewise-constant segments, one list per shape.
# Each tuple is (start, end, level). NOTE(review): start/end look like
# fractions of a symbol period centred on 0 — confirm against the code that
# renders these impulses.
IMPULSE_DEFS: dict[str, list[tuple]] = {
    "RECT": [
        (-0.5, 0.5, 1),
    ],
    "RZ": [
        (-0.5, 0.0, 1),
        (0.0, 0.5, 0),
    ],
    "Manchester": [
        (-0.5, 0.0, -1),
        (0.0, 0.5, 1),
    ],
}
# Impulse names in definition order.
IMPULSES = [*IMPULSE_DEFS]

def impulse_by_name(name: "str | list[tuple]") -> "tuple[str, list[tuple]]":
    """Resolve an impulse spec to a ``(name, segments)`` pair.

    Args:
        name: either a key of IMPULSE_DEFS (e.g. ``"RECT"``) or an already-built
            segment list, which is passed through unchanged.

    Returns:
        ``(name, IMPULSE_DEFS[name])`` for a string, otherwise ``("", name)``.

    Raises:
        KeyError: if *name* is a string that is not a key of IMPULSE_DEFS.
    """
    if not isinstance(name, str):
        # custom segment list supplied directly; it has no registered name
        return "", name
    try:
        return name, IMPULSE_DEFS[name]
    except KeyError:
        known = ", ".join(IMPULSE_DEFS)
        raise KeyError(f"unknown impulse {name!r}; known impulses: {known}") from None


################################################################################
# linecode defs

# Names of the supported line codes.
LINECODES = [
    "NRZ",
    "RZ",
    "Manchester",
    "MLT-3",
    # "5-PAM",  # NOTE(review): disabled — presumably not implemented yet; confirm
]


################################################################################
# modulation defs

_s = math.sqrt(2) / 2  # sin/cos 45°
# Constellation tables: scheme name -> list of (I, Q, bit label) points.
# demodulate_signal() picks the point nearest to the received (I, Q) and reports
# its label, so each label is the bit group carried by that point.
CONSTELLATION_DEFS: dict[str, list[tuple]] = {
    # Amplitude-only schemes (Q == 0). 2-ASK reuses the outer ±1.5 levels of 4-ASK.
    "2-ASK": [
        (-1.5, 0, "0"),
        ( 1.5, 0, "1"),
    ],
    # Levels spaced 1.0 apart; Gray-coded (adjacent levels differ in one bit).
    "4-ASK": [
        (-1.5, 0, "00"),
        (-0.5, 0, "01"),
        ( 0.5, 0, "11"),
        ( 1.5, 0, "10"),
    ],
    # Levels spaced 0.5 apart within ±1.75; Gray-coded.
    "8-ASK": [
        (-1.75, 0, "000"),
        (-1.25, 0, "001"),
        (-0.75, 0, "011"),
        (-0.25, 0, "010"),
        ( 0.25, 0, "110"),
        ( 0.75, 0, "111"),
        ( 1.25, 0, "101"),
        ( 1.75, 0, "100"),
    ],
    "BPSK": [
        (-1, 0, "0"),
        ( 1, 0, "1"),
    ],
    # Unit-radius points at 45°, 135°, 225°, 315°; Gray-coded around the circle.
    "4-PSK": [
        ( _s,  _s, "00"),
        (-_s,  _s, "01"),
        (-_s, -_s, "11"),
        ( _s, -_s, "10"),
    ],
    # ±1 square grid; horizontally/vertically adjacent points differ in one bit.
    "4-QAM": [
        (-1,  1, "00"),
        ( 1,  1, "01"),
        (-1, -1, "10"),
        ( 1, -1, "11"),
    ],
    # 4x4 grid without its corners. 4-bit labels 0000..1011 are assigned in
    # row-major order (not Gray-coded); labels 1100..1111 have no point.
    # NOTE(review): presumably 4-bit groups 1100..1111 cannot be mapped by
    # the modulator — confirm how callers handle them.
    "12-QAM": [
        (-0.5,  1.5, "0000"),
        ( 0.5,  1.5, "0001"),
        (-1.5,  0.5, "0010"),
        (-0.5,  0.5, "0011"),
        ( 0.5,  0.5, "0100"),
        ( 1.5,  0.5, "0101"),
        (-1.5, -0.5, "0110"),
        (-0.5, -0.5, "0111"),
        ( 0.5, -0.5, "1000"),
        ( 1.5, -0.5, "1001"),
        (-0.5, -1.5, "1010"),
        ( 0.5, -1.5, "1011"),
    ],
    # 4x4 grid with spacing 1.0; Gray-coded along both I and Q.
    "16-QAM": [
        (-1.5,  1.5, "0000"),
        (-0.5,  1.5, "0001"),
        ( 0.5,  1.5, "0011"),
        ( 1.5,  1.5, "0010"),
        (-1.5,  0.5, "0100"),
        (-0.5,  0.5, "0101"),
        ( 0.5,  0.5, "0111"),
        ( 1.5,  0.5, "0110"),
        (-1.5, -0.5, "1100"),
        (-0.5, -0.5, "1101"),
        ( 0.5, -0.5, "1111"),
        ( 1.5, -0.5, "1110"),
        (-1.5, -1.5, "1000"),
        (-0.5, -1.5, "1001"),
        ( 0.5, -1.5, "1011"),
        ( 1.5, -1.5, "1010"),
    ],
    # Cross constellation: 6x6 grid without its four corners. Labels follow
    # plain row-major binary order (not Gray-coded).
    "32-QAM": [
        (-1.5,  2.5, "00000"),
        (-0.5,  2.5, "00001"),
        ( 0.5,  2.5, "00010"),
        ( 1.5,  2.5, "00011"),
        (-2.5,  1.5, "00100"),
        (-1.5,  1.5, "00101"),
        (-0.5,  1.5, "00110"),
        ( 0.5,  1.5, "00111"),
        ( 1.5,  1.5, "01000"),
        ( 2.5,  1.5, "01001"),
        (-2.5,  0.5, "01010"),
        (-1.5,  0.5, "01011"),
        (-0.5,  0.5, "01100"),
        ( 0.5,  0.5, "01101"),
        ( 1.5,  0.5, "01110"),
        ( 2.5,  0.5, "01111"),
        (-2.5, -0.5, "10000"),
        (-1.5, -0.5, "10001"),
        (-0.5, -0.5, "10010"),
        ( 0.5, -0.5, "10011"),
        ( 1.5, -0.5, "10100"),
        ( 2.5, -0.5, "10101"),
        (-2.5, -1.5, "10110"),
        (-1.5, -1.5, "10111"),
        (-0.5, -1.5, "11000"),
        ( 0.5, -1.5, "11001"),
        ( 1.5, -1.5, "11010"),
        ( 2.5, -1.5, "11011"),
        (-1.5, -2.5, "11100"),
        (-0.5, -2.5, "11101"),
        ( 0.5, -2.5, "11110"),
        ( 1.5, -2.5, "11111"),
    ],
}
# Scheme names in definition order.
MODULATIONS = list(CONSTELLATION_DEFS.keys())


################################################################################
# channel en/decode
# example implementation of a linear block code

def _parity_matrix(k: int, m: int) -> np.ndarray:
    """Build the (m x k) parity sub-matrix P for a systematic (n,k) code.
    Columns are distinct non-zero m-bit vectors, skipping the m identity columns
    (powers of 2) so that H = [P | I_m] has all non-zero columns — Hamming property.
    """
    if m == 1:
        return np.ones((1, k), dtype=int)
    identity_cols = {1 << i for i in range(m)}
    cols, v = [], 1
    while len(cols) < k:
        if v not in identity_cols:
            cols.append(v)
        v += 1
    P = np.zeros((m, k), dtype=int)
    for j, col in enumerate(cols):
        for i in range(m):
            P[i, j] = (col >> i) & 1
    return P


def encode(bitstr: str, k: int, n: int) -> str:
    """
    Systematically encode *bitstr* with the (n, k) block code from _parity_matrix.

    The data is zero-padded at the end to a whole number of k-bit blocks. Every
    block is followed by its m = n - k parity bits (P @ block mod 2). The
    n-bit codewords are returned as '0'/'1' text separated by single spaces.
    """
    bits = ubitstr.to_bits(bitstr)
    bits = bits + [False] * ((-len(bits)) % k)
    parity_count = n - k
    P = _parity_matrix(k, parity_count)
    # one row per k-bit data block; truthy entries become 1
    data = np.array([1 if b else 0 for b in bits], dtype=int).reshape(-1, k)
    parity = (data @ P.T) % 2
    codewords = np.hstack([data, parity])
    return ' '.join(''.join(str(b) for b in row) for row in codewords.tolist())


def decode(encoded_bitstr: str, k: int, n: int) -> str:
    """
    Decode a bitstring produced by encode() back to its data bits.

    The input is split into n-bit codewords (m = n - k parity bits each), and
    the first k bits of every codeword are kept. A single-bit error is
    corrected via syndrome lookup whenever the code can locate it, i.e. when
    every column of H = [P | I_m] is non-zero and unique. This holds for
    m >= 3 with k <= 2**m - 1 - m, and also for the (3, 1) repetition code.
    Otherwise the bits are passed through uncorrected rather than risking a
    flip of the wrong bit.

    The zero padding that encode() appended to the last block is NOT removed.

    Raises:
        ValueError: if n <= 0, or if the bit count is not a multiple of n.
    """
    if n <= 0:
        raise ValueError(f"codeword length n must be positive, got n={n}")
    m = n - k
    bits = ubitstr.to_bits(encoded_bitstr)
    if len(bits) % n:
        raise ValueError(
            f"encoded length {len(bits)} is not a multiple of codeword length n={n}")
    P = _parity_matrix(k, m)
    H = np.hstack([P, np.eye(m, dtype=int)])  # (m × n) parity-check matrix
    # A single-bit error at position j yields syndrome H[:, j]. The lookup is
    # only trustworthy if no column is zero and no two columns coincide.
    columns = [tuple(col) for col in H.T.tolist()]
    locatable = all(any(col) for col in columns) and len(set(columns)) == len(columns)
    error_pos = {col: j for j, col in enumerate(columns)} if locatable else {}
    result = []
    for start in range(0, len(bits), n):
        word = np.array([1 if b else 0 for b in bits[start:start + n]], dtype=int)
        syndrome = tuple(((H @ word) % 2).tolist())
        j = error_pos.get(syndrome)  # None for a zero or unknown syndrome
        if j is not None:
            word[j] ^= 1
        result.extend(word[:k].tolist())
    return ubitstr.to_bitstr([b == 1 for b in result])



################################################################################
# mod/demod

def demodulate_signal(sig: Signal, scheme: str, carrier_freq: float,
                      impulse: str = 'rect') -> list[ModSymbol]:
    """Coherent IQ demodulation with a matched filter over each full symbol period.

    The symbol period is taken to be one unit of ``sig.t``. The symbol count is
    ``round(t[-1] - t[0])``, and every symbol spans ``len(t) // n_symbols``
    consecutive samples. Mixing with ``2*cos`` and ``-2*sin`` of the carrier
    gives raw I/Q. Each symbol's samples are correlated with the pulse ``g``
    and divided by its energy, so a received pulse ``A*g`` maps to ``A``
    (ignoring the double-frequency mixing products).

    Args:
        sig: signal with sample times ``sig.t`` and sample values ``sig.y``
            (assumed to be equal-length numpy arrays); ``sig.bptu`` is unused.
        scheme: key of CONSTELLATION_DEFS used for the nearest-point decision.
        carrier_freq: carrier frequency in cycles per unit of ``t``.
        impulse: ``'cos2'`` (any letter case) selects a cos^2 pulse as the
            matched filter; any other value selects a rectangular pulse.

    Returns:
        One ModSymbol(I, Q, bits) per symbol: the filtered I/Q estimates and
        the label of the nearest constellation point. Empty if the signal has
        no samples or rounds to fewer than one symbol.

    Raises:
        KeyError: if ``scheme`` is not a key of CONSTELLATION_DEFS.
    """
    points = CONSTELLATION_DEFS[scheme]
    t, signal = sig.t, sig.y
    if len(t) == 0:
        return []
    n_symbols = int(round(t[-1] - t[0]))
    if n_symbols < 1:
        return []  # shorter than one symbol period
    sps = len(t) // n_symbols  # samples per symbol

    phase = 2 * np.pi * carrier_freq * t
    I_raw = signal * np.cos(phase) * 2
    Q_raw = signal * -np.sin(phase) * 2

    # matched-filter pulse sampled over one full symbol period
    tau = np.linspace(0, 1, sps, endpoint=False)
    if isinstance(impulse, str) and impulse.lower() == 'cos2':
        g = np.cos(np.pi * (tau - 0.5)) ** 2
    else:
        g = np.ones(sps)
    g_energy = np.dot(g, g)

    results = []
    for i in range(n_symbols):
        s, e = i * sps, (i + 1) * sps
        I_rx = float(np.dot(I_raw[s:e], g) / g_energy)
        Q_rx = float(np.dot(Q_raw[s:e], g) / g_energy)
        # nearest constellation point (first one wins on ties)
        best = min(points, key=lambda p: (p[0] - I_rx) ** 2 + (p[1] - Q_rx) ** 2)
        results.append(ModSymbol(I_rx, Q_rx, best[2]))

    return results