import numpy as np
import matplotlib.pyplot as plt
import os, termcolor, matplotlib, configparser
from pydarm.darm import DARMModel
import scipy.signal as signal
import foton, ipdb

# Global matplotlib defaults for every figure in this script:
# 14 pt text and a 14 x 9 inch default canvas.
matplotlib.rcParams.update({
    'font.size': 14,
    'figure.figsize': (14, 9),
})

def asd_to_rms(f, asd, reverse=False):
    """Return the cumulative RMS curve of an amplitude spectral density.

    With ``reverse=False`` (default), element ``i`` is
    ``sqrt(sum(|asd[i:]|**2))``. This accumulates from the end of the array
    down to and including index ``i``; for an ascending ``f`` it is the RMS
    of everything at or above ``f[i]``.

    With ``reverse=True``, element ``i`` is ``sqrt(sum(|asd[:i+1]|**2))``.
    This accumulates from the start of the array up to and including
    index ``i``.

    In both modes the curve reaches the total RMS at one end.

    NOTE(review): bins are summed without weighting by the frequency-bin
    width, so the result is a true RMS only for unit (1 Hz) spacing.
    Confirm this is intended.

    Parameters
    ----------
    f : array_like
        Frequency axis. Not used in the computation; kept for interface
        compatibility.
    asd : array_like
        1-D amplitude spectral density values (real or complex).
    reverse : bool
        Accumulate from the start of the array instead of from the end.

    Returns
    -------
    numpy.ndarray
        Float array with the same length as ``asd``.
    """
    # |asd|**2 gives real power for complex input, and the float cast keeps
    # integer input from truncating the square root.
    power = np.abs(np.asarray(asd)).astype(float) ** 2
    if reverse:
        cumulative = np.cumsum(power)
    else:
        # Suffix sums: accumulate the reversed array, then flip back.
        cumulative = np.cumsum(power[::-1])[::-1]
    return np.sqrt(cumulative)

def read_tf_txt(filename, comment_char='#',
                delimiter=" ", freq_sep=" ",
                ):
    """Read a separator-delimited text export of frequency-series data.

    Lines starting with ``comment_char`` are collected as comments, and blank
    lines are ignored. Every other line is split on ``freq_sep``, and empty
    fields (produced by repeated separators) are dropped. The first remaining
    field is the frequency; the rest are the data columns.

    Parameters
    ----------
    filename : str or path-like
        Path to the text file.
    comment_char : str
        Prefix marking a comment line.
    delimiter : str
        Currently unused, kept for backward compatibility. All splitting is
        done with ``freq_sep``.
    freq_sep : str
        Separator used to split each data line into fields.

    Returns
    -------
    freq : numpy.ndarray
        1-D float array with one frequency per data line.
    data : numpy.ndarray
        2-D float array of shape ``(n_lines, n_columns)`` holding the
        remaining columns of each data line.

    Raises
    ------
    ValueError
        If a field is not numeric, or data lines have differing column counts.
    """
    with open(filename, 'r') as file:
        name = file.name
        lines = file.readlines()

    comments = []
    freqs = []
    rows = []
    for lineno, line in enumerate(lines, start=1):
        if line.startswith(comment_char):
            # Drop the comment marker and surrounding whitespace/newline
            # without assuming the line ends in '\n'.
            comments.append(line[len(comment_char):].strip())
            continue
        # Splitting on a single-character separator yields '' for runs of
        # separators; discard those empty fields.
        fields = [sub.strip() for sub in line.split(freq_sep)]
        fields = [field for field in fields if field]
        if not fields:
            continue  # blank line
        values = [float(field) for field in fields[1:]]
        if rows and len(values) != len(rows[0]):
            raise ValueError(
                f"{name}:{lineno}: expected {len(rows[0])} data columns, "
                f"found {len(values)}"
            )
        freqs.append(float(fields[0]))
        rows.append(values)

    print(termcolor.colored(f"\nReading from {name} ...", color='green'))
    print(termcolor.colored(f"{len(freqs)} data points found.", color='white'))

    print('Comments:')
    # The final comment line is not echoed (presumably a column header --
    # confirm against the export format).
    for each in comments[:-1]:
        print('\t', each)

    ncols = len(rows[0]) if rows else 0
    freq = np.asarray(freqs, dtype=float)
    # reshape keeps a 2-D result even when there are no data lines.
    data = np.asarray(rows, dtype=float).reshape(len(rows), ncols)
    return freq, data

if __name__ == '__main__':
    # Resolve inputs/outputs relative to this script so it runs from any cwd.
    here = os.path.dirname(__file__)
    freqs, data = read_tf_txt(os.path.join(here, 'co2x_ac_mon_20250204.txt'))

    # One label per measurement. ``data`` holds the "in" ASDs for every
    # measurement in columns [0, n) and the "out" ASDs in columns [n, 2n).
    labels = [
        "Laser CW - Locked", "Laser CW - Unlocked",
        "Laser PWM - 25kHz 50%", "Laser PWM - 5kHz 50%",
        "Laser PWM - 10kHz 50%"
    ]

    # DC readback [in, out] in counts. Assumed to be in the same order as
    # ``labels`` (one row per measurement); NOTE(review): confirm.
    dc_cts = [
        [5562, 6336],
        [5562, 6336],
        [3553, 3890],
        [4010, 4381],
        [3816, 4143]
    ]
    # Divisor applied to the DC counts before normalising; its physical
    # meaning is not recorded here. NOTE(review): document the origin of 510.
    dc_scale = 510

    # Calibration response removed from the monitor ASDs: a zero at s=0 and a
    # pole given as +20 rad/s (~3.2 Hz), with gain 105.4 dB. Only |H| is used,
    # so the pole's sign does not matter.
    # NOTE(review): confirm whether 20 Hz (2*pi*20 rad/s) was intended.
    _, H = signal.freqs_zpk([0], [20], 10**(105.4/20), worN=freqs*2*np.pi)

    fig = plt.figure(figsize=(14, 9))
    axs = fig.subplots(1, 1)

    # Plot order: the two CW traces, then PWM by frequency (5k, 10k, 25k).
    order = [0, 1, 3, 4, 2]

    for color_idx, k in enumerate(order):
        color = f"C{color_idx}"

        in_asd, out_asd = data[:, k], data[:, k + len(labels)]

        uncalibrated_in_asd = np.abs(in_asd / H)
        uncalibrated_out_asd = np.abs(out_asd / H)

        # Index the DC table by measurement (k), not by plot position, so
        # each trace is normalised by its own DC level.
        uncalibrated_in_dc = dc_cts[k][0] / dc_scale
        uncalibrated_out_dc = dc_cts[k][1] / dc_scale

        rin_in = uncalibrated_in_asd / uncalibrated_in_dc
        rin_out = uncalibrated_out_asd / uncalibrated_out_dc

        axs.loglog(freqs, rin_in, label=labels[k] + ' - In', linewidth=.1,
                   color=color)
        axs.loglog(freqs, rin_out, label=labels[k] + ' - Out', linewidth=.1,
                   ls='--', color=color)
        # Cumulative RMS of the input RIN (dotted, unlabelled).
        axs.loglog(freqs, asd_to_rms(freqs, rin_in), ls=':',
                   linewidth=1, color=color)

    axs.set_xlabel("Frequency [Hz]")
    axs.set_ylabel("RIN (cts/cts)/rtHz")

    axs.grid(True, 'major', 'both', alpha=.5)
    axs.grid(True, 'minor', 'both', alpha=.2)
    axs.legend(fontsize=11)

    fig.suptitle("RIN CO2X Laser OUT : 04 February 2025 : H1:IOP-OAF_L0_MADC3_TP_CH10/12")

    # Name the figure after the measurement date (matches the data file and
    # title) so earlier figures are not overwritten; create the directory
    # if it is missing.
    fig_dir = os.path.join(here, 'figures')
    os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(os.path.join(fig_dir, '20250204_co2x_rin.pdf'))
    
    
