# A script that uses a set of IFO measurements taken at different NLGs to fit the loss.

import os, sys, warnings
warnings.filterwarnings("ignore")
import numpy as np
import matplotlib.pyplot as plt
import h5py
#import pickle
import yaml

#ligo packages to import
import dttxml 
import gwpy.timeseries
import gwinc

#modules to import
import logbinning
import utils

# Log-spaced frequency grid (1 Hz to 7 kHz, 1000 points) onto which every ASD is binned.
fflog = np.logspace(np.log10(1), np.log10(7000), 1000)

# Earlier SQZ-HD NLG sweep (20250109 data, alog 82202), kept for reference only.
# This is a bare string literal, so none of it is executed.
'''
HD_dtt = dttxml.DiagAccess("dtt_files/20250109SQZHD.xml")  #data from alog 82202
freq = HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF0)').FHz
dark_noise = logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF0)').asd)
shot_noise = logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF1)').asd)


meas_dict = {'nlg11':{'nlg':11.2,
                      'opo_trans':80,  #80uW OPO trans set point
                         'asqz':logbinning.logbin_asd(fflog, freq, HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF2)').asd),
                          'sqz':logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF3)').asd),
                          },
            'nlg14':{'nlg':14.3,
                    'opo_trans':120,
                    'asqz':logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF5)').asd),
                    'sqz':logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF4)').asd)
                          },
            'nlg16':{'nlg':16,
                    'opo_trans':140,
                    'asqz':logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF6)').asd),
                    'sqz':logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF8)').asd)
                          },
            }
'''
dtt = dttxml.DiagAccess("dtt_files/20250314_NLG.xml")  #data from alog 83370
# Starting without any dark-noise subtraction, so that we can estimate where the
# technical noise is.  If we want, we can later come back, do the subtraction and
# check that it agrees with what we estimate from the NLG scan.
DELTAL_CHANNEL = 'H1:CAL-DELTAL_EXTERNAL_DQ'
freq = dtt.asd(f'{DELTAL_CHANNEL}(REF0)').FHz
# (dark-noise line below needs the SQZ-HD dataset in the string above)
#dark_noise = logbinning.logbin_asd(fflog, freq,HD_dtt.asd('H1:SQZ-HD_DIFF_DC_OUT_DQ(REF0)').asd)


def _logbin_deltal_ref(ref_num):
    """Return the DELTAL ASD from DTT reference trace REF<ref_num>, log-binned onto fflog."""
    return logbinning.logbin_asd(fflog, freq, dtt.asd(f'{DELTAL_CHANNEL}(REF{ref_num})').asd)


# One row per OPO-trans setpoint:
#   (dict key, NLG from amp/unamp, NLG from amp/deamp, OPO trans setpoint [uW],
#    DTT reference numbers for (anti-sqz, sqz, mean sqz))
# REF0 is the no-squeezing reference and is loaded separately.
_NLG_SWEEP = (
    ('opo_trans_120', 61,   62.7, 120, (1, 3, 4)),
    ('opo_trans_110', 35.2, 35.4, 110, (5, 6, 7)),
    ('opo_trans_90',  16.2, 16.8, 90,  (8, 9, 10)),
    ('opo_trans_70',  8.9,  9.2,  70,  (11, 12, 13)),
    ('opo_trans_40',  4.0,  4.2,  40,  (14, 15, 16)),
    ('opo_trans_25',  2.7,  2.8,  25,  (17, 18, 19)),
)

# Values inside each entry are evaluated in the order asqz, sqz, msqz, so the
# DTT references are read in the same order as listed in the table.
meas_dict = {
    key: {
        'nlg': nlg,
        'nlg_max_min': nlg_max_min,
        'opo_trans': trans,
        'asqz': _logbin_deltal_ref(asqz_ref),
        'sqz': _logbin_deltal_ref(sqz_ref),
        'msqz': _logbin_deltal_ref(msqz_ref),
    }
    for key, nlg, nlg_max_min, trans, (asqz_ref, sqz_ref, msqz_ref) in _NLG_SWEEP
}
#meas_dict = meas_dict_Feb25

#subtract dark noise from all asds
#shot_noise_subtracted = np.sqrt(shot_noise**2 - dark_noise**2)
#for key in meas_dict.keys():
#    meas_dict[key]['asqz_sub'] = np.sqrt(meas_dict[key]['asqz']**2 - dark_noise**2)
#    meas_dict[key]['sqz_sub'] = np.sqrt(meas_dict[key]['sqz']**2 - dark_noise**2)
#    meas_dict[key]['msqz_sub'] = np.sqrt(meas_dict[key]['msqz']**2 - dark_noise**2)

# Apply the DTT DELTAL calibration (evaluated on fflog) to every binned ASD, then
# express each squeezer configuration relative to the no-squeezing (REF0)
# spectrum, in dB, at every frequency.
deltal_calibration = utils.import_dtt_calib(fflog)
nosqz_raw = logbinning.logbin_asd(fflog, freq, dtt.asd('H1:CAL-DELTAL_EXTERNAL_DQ(REF0)').asd)
nosqz = deltal_calibration * nosqz_raw
for meas in meas_dict.values():
    for quad in ('asqz', 'sqz', 'msqz'):
        meas[quad] *= deltal_calibration
    for quad in ('sqz', 'asqz', 'msqz'):
        meas[f'{quad}_dB_all_freq'] = utils.dB(meas[quad] / nosqz)

# Calibrated displacement spectra for every NLG, one panel per squeezer
# configuration, each drawn against the no-squeezing reference.
fig1, [ax1, ax2, ax3] = plt.subplots(3, 1, figsize=[10, 12])
for ax in (ax1, ax2, ax3):
    ax.loglog(fflog, nosqz, label='no squeezing')
for meas in meas_dict.values():
    name = 'nlg = ' + str(meas['nlg'])
    ax1.loglog(fflog, meas['sqz'], label=name)
    ax2.loglog(fflog, meas['msqz'], label=name)
    ax3.loglog(fflog, meas['asqz'], label=name)
# Every panel gets its traces in the same order, so colors match across panels
# and a legend on the top panel is enough.
ax1.legend()
ax1.set_title(' Frequency Independent Squeezing')
ax2.set_title('Mean squeezing (LO loop unlocked)')
ax3.set_title('Anti squeezing')
ax3.set_xlabel('Frequency [Hz]')

for ax in (ax1, ax2, ax3):
    ax.set_xlim([10, 6000])
    ax.set_ylabel('Displacement [m/rt Hz]')
    ax.set_ylim([1e-20, 1e-17])
plot_file_name = 'March14_nlg_sweep_spectra'
utils.save_plots(fig1, 'plots', plot_file_name)


# Two frequency bands in which each squeezing level is summarised by a median.
# The median is taken of the linear ASD ratio (meas / no-sqz) and converted to
# dB only afterwards, i.e. before taking logs.

# Band 1: clean, no technical noise; high NLG seems to give more sqz than lower NLGs here.
fmin, fmax = 2060, 2200
fmin_indx, fmax_indx = utils.find_freq_indeces(fflog, fmin, fmax)
fmid = (fmin + fmax) / 2
band1 = slice(fmin_indx, fmax_indx)

# Band 2 (previously tried: 1630-1700 Hz).
fmin2, fmax2 = 1230, 1270
fmin2_indx, fmax2_indx = utils.find_freq_indeces(fflog, fmin2, fmax2)
fmid2 = (fmin2 + fmax2) / 2
band2 = slice(fmin2_indx, fmax2_indx)

nosqz_median = np.median(nosqz[band1])
nosqz_median2 = np.median(nosqz[band2])

for meas in meas_dict.values():
    for suffix, band in (('_dB', band1), ('_dB2', band2)):
        for quad in ('sqz', 'asqz', 'msqz'):
            meas[quad + suffix] = utils.dB(np.median(meas[quad][band] / nosqz[band]))

# Gather the per-measurement scalars into 1-D float arrays (one entry per
# OPO-trans setpoint, in meas_dict order) for the fits and plots below.
# Each array is built in one pass, instead of growing it with np.append,
# which copies the whole array on every call.
def _collect(field):
    """Return ``meas[field]`` for every measurement in meas_dict as a flat float array."""
    # dtype=float + ravel reproduce np.append's float64, flattened result.
    return np.asarray([meas[field] for meas in meas_dict.values()], dtype=float).ravel()


nlgs = _collect('nlg')
nlgs_max_min = _collect('nlg_max_min')
opo_trans = _collect('opo_trans')
sqz_dB = _collect('sqz_dB')
asqz_dB = _collect('asqz_dB')
msqz_dB = _collect('msqz_dB')
sqz_dB2 = _collect('sqz_dB2')
asqz_dB2 = _collect('asqz_dB2')
msqz_dB2 = _collect('msqz_dB2')

# Squeezing level in dB relative to no squeezing vs frequency, for each NLG.
# The band medians are overlaid at the band centres ('+' band 1, 'o' band 2).
fig1, [ax1, ax2, ax3] = plt.subplots(3, 1, figsize=[10, 8])
panels = (
    (ax1, 'sqz', sqz_dB, sqz_dB2),
    (ax2, 'msqz', msqz_dB, msqz_dB2),
    (ax3, 'asqz', asqz_dB, asqz_dB2),
)
for ax, quad, band1_dB, band2_dB in panels:
    for meas in meas_dict.values():
        name = 'nlg = ' + str(meas['nlg'])
        ax.semilogx(fflog, meas[quad + '_dB_all_freq'], label=name + ' ' + quad)
    ax.semilogx(fmid * np.ones_like(band1_dB), band1_dB, marker='+')
    ax.semilogx(fmid2 * np.ones_like(band2_dB), band2_dB, marker='o')
    ax.legend()
    ax.set_xlim([10, 6000])
    ax.set_xlabel('Frequency [Hz]')
    ax.set_ylabel('QN relative to no sqz [dB]')
    ax.set_ylim([-5, 25])
plot_file_name = 'March14_nlg_sweep'
# The file name was previously set but the figure was never written out.
utils.save_plots(fig1, 'plots', plot_file_name)

# Instead of fitting everything to sqz/asqz at once, first fit the OPO threshold
# to the NLG measurements alone (once per NLG measurement method), then fit the
# losses from the mean sqz (or asqz).
p_thresh_fit_normal = utils.fit_threshold([opo_trans, nlgs])          # amp/unamp NLGs
p_thresh_fit_maxMin = utils.fit_threshold([opo_trans, nlgs_max_min])  # amp/deamp NLGs

# Fit the total efficiency to the band-median mean sqz, using the amp/unamp threshold.
# NOTE(review): `eta` is reassigned in the loss-budget section at the end of the
# script, and `eta2` is not read again in this file.
eta = utils.fit_eta_to_msqz_dBs([opo_trans, msqz_dB, p_thresh_fit_normal])
eta2 = utils.fit_eta_to_msqz_dBs([opo_trans, msqz_dB2, p_thresh_fit_normal])

# Efficiency inferred at every frequency from the mean-sqz spectrum:
# 'eta' uses the amp/unamp threshold fit, 'eta2' the amp/deamp one.
for meas in meas_dict.values():
    meas['eta'] = utils.eta_from_msqz(meas['msqz_dB_all_freq'], meas['opo_trans'], p_thresh_fit_normal)
    meas['eta2'] = utils.eta_from_msqz(meas['msqz_dB_all_freq'], meas['opo_trans'], p_thresh_fit_maxMin)

fig, ax = plt.subplots(1, 1, figsize=[8, 8])
for meas in meas_dict.values():
    ax.semilogx(fflog, meas['eta2'], label='nlg = ' + str(meas['nlg']))
ax.set_ylim([0, 1])
ax.set_xlim([10, 6000])
ax.set_ylabel('Total efficiency inferred from mean sqz')
ax.set_xlabel('Frequency [Hz]')
ax.legend()
plot_file_name = 'Eta_infered_from_mean_sqz'
utils.save_plots(fig, 'plots', plot_file_name)

# OPO trans power grid [uW] on which the fitted models are evaluated.
opo_trans_vector = np.linspace(0, 500, num=600)

# Joint fit of [threshold (uW OPO trans), total efficiency, phase noise (rad)]
# to the band-median sqz, anti-sqz and mean-sqz levels, once per frequency band.
fit_dBs_args = [opo_trans, sqz_dB, asqz_dB, msqz_dB]  # band 1 (fmin-fmax)
fit_params = utils.fit_sqz_asqz_msqz_dBs(fit_dBs_args)
fit_dBs_args2 = [opo_trans, sqz_dB2, asqz_dB2, msqz_dB2]  # band 2 (fmin2-fmax2)
fit_params2 = utils.fit_sqz_asqz_msqz_dBs(fit_dBs_args2)
p_thresh_fit = fit_params[0]
#params = [156, 0.456, 0.03]
[model_sqz_dB, model_asqz_dB, model_msqz_dB] = utils.dB(utils.sqz_asqz_msqz_from_params(fit_params, opo_trans_vector), power=True)
[model_sqz_dB2, model_asqz_dB2, model_msqz_dB2] = utils.dB(utils.sqz_asqz_msqz_from_params(fit_params2, opo_trans_vector), power=True)

# Reference model: fixed 158 uW threshold, no phase noise, and an assumed total
# efficiency of 0.83 * 0.855 (about 71%).  The plot label is derived from
# expected_eta so it states the efficiency actually modelled.
expected_eta = 0.83 * 0.855
[expected_model_sqz_dB, expected_model_asqz_dB, expected_model_msqz_dB] = utils.dB(utils.sqz_asqz_msqz_from_params([158, expected_eta, 0], opo_trans_vector), power=True)


# NLG vs OPO trans power: measurements and the three threshold fits.
fig1, ax1 = plt.subplots(1, 1, figsize=[8, 8])
ax1.semilogy(opo_trans, nlgs, label='measured amp/unamp', marker='o', linestyle='None')
ax1.semilogy(opo_trans, nlgs_max_min, label='measured NLG amplified/deamplified method', marker='+', linestyle='None')
ax1.semilogy(opo_trans_vector, utils.nlg_from_power(opo_trans_vector, p_thresh_fit), label=f'{p_thresh_fit:0.1f} uW, sqz and asqz fit')
ax1.semilogy(opo_trans_vector, utils.nlg_from_power(opo_trans_vector, p_thresh_fit_normal), label=f'{p_thresh_fit_normal[0]:0.1f} uW, amplified/unamplified fit')
ax1.semilogy(opo_trans_vector, utils.nlg_from_power(opo_trans_vector, p_thresh_fit_maxMin), label=f'{p_thresh_fit_maxMin[0]:0.1f} uW, amp + deamp fit')
ax1.set_xlabel('OPO trans power [uW]')
ax1.set_ylabel('non linear gain')
ax1.set_xlim([0, 150])
# The y axis is log-scaled: matplotlib ignores a non-positive lower limit, and the
# warning it emits is hidden by the filterwarnings call at the top of the script.
# So start the axis at 1 instead of 0.
ax1.set_ylim([1, 80])
ax1.legend()
date = 'March 14'
# Keep file names consistent with the other 'March14_*' plots (no space, '_' separator).
plot_file_name = date.replace(' ', '') + '_fitting_opo_thresh'
utils.save_plots(fig1, 'plots', plot_file_name)

# Measured band medians vs OPO trans power, with the fitted and reference models.
fig1, ax2 = plt.subplots(1, 1, figsize=[12, 8])
ax2.plot(opo_trans, sqz_dB, label=f'measured sqz {fmid:0.0f} Hz', marker='o', linestyle='None')
ax2.plot(opo_trans, asqz_dB, label=f'measured anti-sqz {fmid:0.0f} Hz', marker='o', linestyle='None')
ax2.plot(opo_trans, msqz_dB, label=f'measured mean sqz {fmid:0.0f} Hz', marker='o', linestyle='None')
ax2.plot(opo_trans, sqz_dB2, label=f'measured sqz {fmid2:0.0f} Hz', marker='+', linestyle='None')
ax2.plot(opo_trans, asqz_dB2, label=f'measured anti-sqz {fmid2:0.0f} Hz', marker='+', linestyle='None')
ax2.plot(opo_trans, msqz_dB2, label=f'measured mean sqz {fmid2:0.0f} Hz', marker='+', linestyle='None')
ax2.plot(opo_trans_vector, model_sqz_dB, label=f'sqz model {fmid:0.0f} Hz')
ax2.plot(opo_trans_vector, model_asqz_dB, label=f'asqz model {fmid:0.0f} Hz')
ax2.plot(opo_trans_vector, model_msqz_dB, label=f'mean sqz model {fmid:0.0f} Hz')
ax2.plot(opo_trans_vector, model_sqz_dB2, label=f'sqz model {fmid2:0.0f} Hz', linestyle='--')
ax2.plot(opo_trans_vector, model_asqz_dB2, label=f'asqz model {fmid2:0.0f} Hz', linestyle='--')
ax2.plot(opo_trans_vector, model_msqz_dB2, label=f'mean sqz model {fmid2:0.0f} Hz', linestyle='--')
ax2.plot(opo_trans_vector, expected_model_sqz_dB, label=f'sqz model, {expected_eta:.0%} total efficiency', linestyle='--')
ax2.set_xlabel('OPO trans power [uW]')
ax2.set_ylabel('measured QN [dB]')
ax2.set_xlim([0, 150])
ax2.set_ylim([-10, 30])
ax2.legend()
title_string = f"{fmin} to {fmax} Hz fit:"+\
     f" threshold {fit_params[0]:0.0f} uW OPO trans, total efficiency {fit_params[1]:0.2f} and phase noise {1e3*fit_params[2]:0.0f} mrad \n" +\
     f"{fmin2} to {fmax2} Hz fit: "+\
     f"threshold {fit_params2[0]:0.0f} uW OPO trans, total efficiency {fit_params2[1]:0.2f} and phase noise {1e3*fit_params2[2]:0.0f} mrad"

ax2.set_title(title_string)

plot_file_name = 'March14_nlg_sweep_measurement_2k_and1p3k'

utils.save_plots(fig1, 'plots', plot_file_name)


# Loss budgeting: budgeted efficiency factors (fractions) for individual elements.
# Their product is the expected total efficiency.
opo_esc = 0.985
SFIs = 0.99**3
FC_WFS = 0.99
ZM456 = 0.99
OFI = 0.99*0.995
SRC_loss = 0.99
OM1 = 0.9993
OM3 = 0.985
OMC_QPD = 0.9904
OMC = 0.956
QE = 0.98
extra_HAM7 = 0.94  #alog 83070

# Compare the fitted total efficiency (band 1) with the budget.
eta = fit_params[1]
eta_expected = opo_esc * SFIs * FC_WFS * ZM456 * OFI * SRC_loss * OM1 * OM3 * OMC_QPD * OMC * QE
# Ratio of fitted to budgeted efficiency; < 1 means loss beyond the budget.
eta_unexpected = eta / eta_expected
# Ratio left after also attributing the extra HAM7 factor
# (presumably loss in the interferometer, per the name -- confirm).
extra_ifo = eta_unexpected / extra_HAM7

# Report the budget; previously these values were computed and discarded.
print(f'fitted total efficiency ({fmin}-{fmax} Hz): {eta:0.3f}')
print(f'expected total efficiency from loss budget: {eta_expected:0.3f}')
print(f'fitted / expected efficiency: {eta_unexpected:0.3f}')
print(f'fitted / expected efficiency after extra HAM7 loss: {extra_ifo:0.3f}')