import dttxml 
import numpy as np
import matplotlib.pyplot as plt



def import_dtt_calib(
    freq,
    *,
    freq_file='/ligo/gitcommon/NoiseBudget/aligoNB/aligoNB/H1/calibrations/calcs_dtt_calibration_model_20230125_freqs.txt',
    tf_file='/ligo/gitcommon/NoiseBudget/aligoNB/aligoNB/H1/calibrations/calcs_dtt_calibration_model_20230125_tf.txt',
    return_complex=False,
):
    """Load the CALCS DTT calibration model and interpolate it onto ``freq``.

    The model is needed for the OMC DCPD whitening chassis parameters and
    should only need updating when those change. It is read from text files
    (made for the DTT calibration after the Jan 2023 OMC whitening chassis
    change) instead of being generated from a newer pydarm ``.ini`` file,
    because those ``.ini`` files require a different pydarm version.

    Parameters
    ----------
    freq : array_like
        Frequencies [Hz] at which to evaluate the calibration.
    freq_file : str, keyword-only
        Text file with the model frequency vector [Hz]. ``np.interp``
        requires these to be increasing.
    tf_file : str, keyword-only
        Text file with the complex model transfer function, one value per
        entry of ``freq_file``.
    return_complex : bool, keyword-only
        If False (default), return only the interpolated magnitude, which is
        what is needed to scale an ASD. If True, also interpolate the
        (unwrapped) phase and return the complex transfer function.

    Returns
    -------
    numpy.ndarray
        Interpolated |TF| (real) or the complex TF, sampled at ``freq``.
        Outside the model's frequency range the end values are held
        (``np.interp`` behavior).
    """
    model_freq = np.loadtxt(freq_file)
    model_tf = np.loadtxt(tf_file, dtype=np.complex128)

    mag = np.interp(freq, model_freq, np.abs(model_tf))
    if not return_complex:
        return mag

    # Interpolate the *unwrapped* phase: interpolating np.angle directly would
    # sweep through the wrong values wherever the phase wraps at +/-pi.
    phase = np.interp(freq, model_freq, np.unwrap(np.angle(model_tf)))
    return mag * np.exp(1j * phase)



def excess_power_coupling(
    f_budget,
    freq,
    wit_quiet_chan,
    wit_exc_chan,
    target_quiet_chan,
    target_exc_chan,
    dtt_obj,
    *,
    wit_ratio=1.7,
    target_ratio=2.0,
):
    """Compute the excess-power ASD coupling of a witness to a target (normally DARM).

    The coupling is ``sqrt((P_target_exc - P_target_quiet) /
    (P_wit_exc - P_wit_quiet))``, where ``P`` are PSDs (squared DTT ASDs).
    It is computed on the measured grid ``freq``, then interpolated onto
    ``f_budget``.

    Bins are set to NaN where the injection was not clearly visible. These are
    bins where the witness ASD rose by less than ``wit_ratio`` over quiet, or
    the target ASD rose by less than ``target_ratio`` (the projection would
    then only be an upper limit). The masking is decided on the interpolated
    data.

    Parameters
    ----------
    f_budget : array_like
        Frequencies [Hz] on which to return the coupling.
    freq : array_like
        Frequencies [Hz] the DTT ASDs are sampled on (increasing, as required
        by ``np.interp``).
    wit_quiet_chan, wit_exc_chan : str
        DTT channel names for the witness during quiet time / excitation.
    target_quiet_chan, target_exc_chan : str
        DTT channel names for the target during quiet time / excitation.
    dtt_obj
        Object whose ``asd(chan).asd`` returns the ASD for a channel
        (e.g. ``dttxml.DiagAccess``).
    wit_ratio, target_ratio : float, keyword-only
        Minimum ASD ratio (excited / quiet) required in the witness and target
        for a bin to be kept.

    Returns
    -------
    numpy.ndarray
        ASD coupling (target units per witness unit) on ``f_budget``, NaN in
        masked bins.
    """
    def _psd(chan):
        return dtt_obj.asd(chan).asd**2

    psd_wit_quiet = _psd(wit_quiet_chan)
    psd_wit_exc = _psd(wit_exc_chan)
    psd_target_quiet = _psd(target_quiet_chan)
    psd_target_exc = _psd(target_exc_chan)

    # Coupling of witness PSD to target PSD, before masking or interpolation.
    # Bins where the injection did not raise a channel give zero/negative
    # differences; they are masked below, so don't warn about them here.
    with np.errstate(divide='ignore', invalid='ignore'):
        cpl_proj = (psd_target_exc - psd_target_quiet) / (psd_wit_exc - psd_wit_quiet)

    # Interpolate the coupling and the measurements so the mask is decided on
    # the same grid the result is returned on.
    psd_target_exc = np.interp(f_budget, freq, psd_target_exc)
    psd_target_quiet = np.interp(f_budget, freq, psd_target_quiet)
    psd_wit_exc = np.interp(f_budget, freq, psd_wit_exc)
    psd_wit_quiet = np.interp(f_budget, freq, psd_wit_quiet)
    coupling = np.interp(f_budget, freq, cpl_proj)

    # Boolean mask of bins with too little excess power (thresholds are on ASD,
    # hence squared for the PSD ratios).
    with np.errstate(divide='ignore', invalid='ignore'):
        weak = ((psd_wit_exc / psd_wit_quiet) < wit_ratio**2) | (
            (psd_target_exc / psd_target_quiet) < target_ratio**2
        )
    coupling = np.where(weak, np.nan, coupling)

    # Any remaining negative coupling has no physical ASD; sqrt gives NaN.
    with np.errstate(invalid='ignore'):
        return np.sqrt(coupling)

def excess_power_projection(
    f_budget,
    freq,
    wit_quiet_chan,
    wit_exc_chan,
    target_quiet_chan,
    target_exc_chan,
    dtt_obj,
):
    """Project the quiet-time witness noise into a target channel (normally DARM), as an ASD.

    The result is the excess-power ASD coupling from ``excess_power_coupling``
    multiplied by the quiet-time witness ASD. Both are on ``f_budget``. Bins
    masked in the coupling stay NaN.

    Parameters
    ----------
    f_budget : array_like
        Frequencies [Hz] on which to return the projection.
    freq : array_like
        Frequencies [Hz] the DTT ASDs are sampled on (increasing).
    wit_quiet_chan, wit_exc_chan, target_quiet_chan, target_exc_chan : str
        DTT channel names, as for ``excess_power_coupling``.
    dtt_obj
        Object whose ``asd(chan).asd`` returns the ASD for a channel.

    Returns
    -------
    numpy.ndarray
        Projected target ASD on ``f_budget`` (in target units).
    """
    coupling = excess_power_coupling(
        f_budget,
        freq,
        wit_quiet_chan,
        wit_exc_chan,
        target_quiet_chan,
        target_exc_chan,
        dtt_obj,
    )
    # The coupling comes back on f_budget, while the DTT ASD is on freq: put
    # the witness on the same grid before multiplying (a no-op if they match).
    witness_quiet_asd = np.interp(f_budget, freq, dtt_obj.asd(wit_quiet_chan).asd)
    return coupling * witness_quiet_asd


def _plot_asds(curves, ylabel, outfile=None, ylim=None):
    """Plot ``(freq, asd, label)`` curves log-log with the shared axis styling.

    Saves to ``outfile`` when given and returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots()
    for f_plot, asd, label in curves:
        ax.loglog(f_plot, asd, label=label)
    ax.set_xlabel('Frequency [Hz]')
    ax.set_ylabel(ylabel)
    ax.set_xlim([10, 7e3])
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.grid(True)
    ax.legend()
    if outfile is not None:
        fig.savefig(outfile)
    return fig, ax


# Each excitation file holds the excited (live) traces; the channels with a
# (REFn) suffix are used as the quiet-time references.
PRCL_dtt = dttxml.DiagAccess("PRCL_excitation.xml")
# Get the budget frequency vector and the DARM calibration magnitude on it.
placeholder = PRCL_dtt.asd('H1:CAL-DELTAL_EXTERNAL_DQ')
freq = placeholder.FHz
darm_cal = import_dtt_calib(freq)

# PRCL projection to DARM.
prcl_to_darm_projection = darm_cal * excess_power_projection(
    freq,
    freq,
    'H1:LSC-PRCL_OUT_DQ(REF8)',         # witness quiet
    'H1:LSC-PRCL_OUT_DQ',               # witness excitation
    'H1:CAL-DELTAL_EXTERNAL_DQ(REF7)',  # target quiet
    'H1:CAL-DELTAL_EXTERNAL_DQ',        # target excitation
    PRCL_dtt,
)

quiet_darm_asd = darm_cal * PRCL_dtt.asd('H1:CAL-DELTAL_EXTERNAL_DQ(REF7)').asd

# PRCL projection to SRCL (in SRCL_OUT counts).
prcl_to_srcl_projection = excess_power_projection(
    freq,
    freq,
    'H1:LSC-PRCL_OUT_DQ(REF8)',   # witness quiet
    'H1:LSC-PRCL_OUT_DQ',         # witness excitation
    'H1:LSC-SRCL_OUT_DQ(REF15)',  # target quiet
    'H1:LSC-SRCL_OUT_DQ',         # target excitation
    PRCL_dtt,
)

quiet_srcl_ctrl_asd = PRCL_dtt.asd('H1:LSC-SRCL_OUT_DQ(REF15)').asd

# SRCL to DARM, used both directly and for the PRCL -> SRCL -> DARM path.
SRCL_dtt = dttxml.DiagAccess("SRCL_excitation.xml")
# The helpers interpolate from their `freq` argument, so it must be this
# file's own frequency vector, not PRCL's (identical if the bins match).
srcl_freq = SRCL_dtt.asd('H1:CAL-DELTAL_EXTERNAL_DQ').FHz
srcl_to_darm_coupling = excess_power_coupling(
    freq,
    srcl_freq,
    'H1:LSC-SRCL_OUT_DQ(REF7)',         # witness quiet
    'H1:LSC-SRCL_OUT_DQ',               # witness excitation
    'H1:CAL-DELTAL_EXTERNAL_DQ(REF8)',  # target quiet
    'H1:CAL-DELTAL_EXTERNAL_DQ',        # target excitation
    SRCL_dtt,
)

# PRCL to DARM projection through SRCL (PRCL coupling via SRCL).
prcl_to_darm_through_srcl = darm_cal * prcl_to_srcl_projection * srcl_to_darm_coupling

srcl_to_darm_projection = darm_cal * excess_power_projection(
    freq,
    srcl_freq,
    'H1:LSC-SRCL_OUT_DQ(REF7)',         # witness quiet
    'H1:LSC-SRCL_OUT_DQ',               # witness excitation
    'H1:CAL-DELTAL_EXTERNAL_DQ(REF8)',  # target quiet
    'H1:CAL-DELTAL_EXTERNAL_DQ',        # target excitation
    SRCL_dtt,
)

# PRCL to MICH, for the PRCL -> MICH -> DARM path.
prcl_to_mich_projection = excess_power_projection(
    freq,
    freq,
    'H1:LSC-PRCL_OUT_DQ(REF8)',  # witness quiet
    'H1:LSC-PRCL_OUT_DQ',        # witness excitation
    'H1:CAL-CS_MICH_DQ(REF16)',  # target quiet
    'H1:CAL-CS_MICH_DQ',         # target excitation
    PRCL_dtt,
)

# MICH to DARM coupling, measured in its own file (and on its own grid).
MICH_dtt = dttxml.DiagAccess("MICH_excitation.xml")
mich_freq = MICH_dtt.asd('H1:CAL-DELTAL_EXTERNAL_DQ').FHz

mich_to_darm_projection = darm_cal * excess_power_projection(
    freq,
    mich_freq,
    'H1:LSC-MICH_OUT_DQ(REF7)',         # witness quiet
    'H1:LSC-MICH_OUT_DQ',               # witness excitation
    'H1:CAL-DELTAL_EXTERNAL_DQ(REF6)',  # target quiet
    'H1:CAL-DELTAL_EXTERNAL_DQ',        # target excitation
    MICH_dtt,
)

mich_to_darm_coupling = excess_power_coupling(
    freq,
    mich_freq,
    'H1:LSC-MICH_OUT_DQ(REF7)',         # witness quiet
    'H1:LSC-MICH_OUT_DQ',               # witness excitation
    'H1:CAL-DELTAL_EXTERNAL_DQ(REF6)',  # target quiet
    'H1:CAL-DELTAL_EXTERNAL_DQ',        # target excitation
    MICH_dtt,
)
prcl_to_darm_through_mich = darm_cal * prcl_to_mich_projection * mich_to_darm_coupling

# Quadrature sum of the direct MICH/SRCL/PRCL projections. A bin that is
# masked (NaN) in any one projection is NaN in the total.
lsc_total_projection = np.sqrt(
    mich_to_darm_projection**2 + srcl_to_darm_projection**2 + prcl_to_darm_projection**2
)

fig, ax = _plot_asds(
    [
        (freq, quiet_darm_asd, 'Quiet DARM'),
        (freq, prcl_to_darm_projection, 'PRCL to DARM projection'),
        (freq, srcl_to_darm_projection, 'SRCL to DARM projection'),
        (freq, mich_to_darm_projection, 'MICH to DARM projection'),
        (freq, lsc_total_projection, 'LSC to DARM total'),
    ],
    'displacement [m/rt Hz]',
    'LSC_projections_to_darm_May9_2024.png',
    ylim=[3e-21, 3e-17],
)

fig, ax = _plot_asds(
    [
        (freq, quiet_darm_asd, 'Quiet DARM'),
        (freq, prcl_to_darm_projection, 'PRCL to DARM projection'),
        (freq, prcl_to_darm_through_mich, 'PRCL to DARM through MICH projection'),
        (freq, prcl_to_darm_through_srcl, 'PRCL to DARM through SRCL projection'),
    ],
    'displacement [m/rt Hz]',
    'PRCL_projections_to_darm_May9_2024.png',
    ylim=[3e-21, 3e-17],
)

# SRCL is in control-signal counts, not displacement.
fig, ax = _plot_asds(
    [
        (freq, quiet_srcl_ctrl_asd, 'Quiet SRCL out'),
        (freq, prcl_to_srcl_projection, 'PRCL to SRCL'),
    ],
    'SRCL out [cnts/rt Hz]',
    'PRCL_projection_to_srcl_May9_2024.png',
)

