% ============================== % 
% This script studies the colored fading channel model. 
% We need to find a good estimate for this channel: the script compares 
% several channel estimators by their average NMSE over a range of SNRs. 
% 
% ============================== % 
% Start from an empty console, no open figures and a clean workspace.
clc; close all; clear;

% Put the project helper folders on the MATLAB path.
addpath('algs/'); 
addpath('utils/');

% Carrier-related constants.
fc      = 3.5e9;                        % carrier frequency handed to setup()
c       = physconst('lightspeed');      % speed of light
lambda  = c/fc;                         % wavelength at fc
k0      = 2*pi/lambda;                  % wavenumber at fc


%% Get channel and precode. 

% Seed the generator before setup() so that the whole run is repeatable.
rng(0); 

% Scalar simulation settings.
sigma2      = 1;    % noise variance used when drawing the noise vector
Ncombiner   = 40;   % number of combiner outputs (reported as "Pilots" below)
N_test      = 40;   % channel realizations per SNR point
P           = 10;   % initial power; the SNR sweep assigns P again per SNR point
L           = 3;    % third argument of setup(); meaning defined by that helper

% Build the system description and attach the settings above to it.
SIMO_params = setup("SIMOULAwithFixedK", fc, L);
SIMO_params.Ncombiner = Ncombiner; 
SIMO_params.P = P; 
SIMO_params.sigma2 = sigma2; 
% Normalize the average power of each row of the channel matrix to 1.
SIMO_params.normalizeChannel = true; 

% Receive-array size as reported by the setup.
NRx = SIMO_params.RxArray.N; 

% SNR grid in dB (column vector) and the result table:
% one row per SNR point, one column per estimator.
SNR_arr = transpose(0:4:20); 
N_SNR   = length(SNR_arr); 
NMSE    = zeros(N_SNR, 6); 

% Console banner; the column order matches the columns of NMSE.
fprintf('Simulation with %d Pilots\n', SIMO_params.Ncombiner); 
fprintf('PSWF w/ Stat CSI  | PSWF w/o Stat CSI  |  RandCombiner  w/o Stat CSI  |  AMP  |  SBAR   |  Adaptive PSWF w/o Stat CSI\n'); 

% Monte-Carlo evaluation of six channel estimators over the SNR grid.
%
% For every entry of SNR_arr, N_test channel realizations are drawn and each
% estimator's normalized squared error ||H - h_hat||^2 / ||H||^2 is recorded.
% The per-SNR average is stored in NMSE(i_snr, :). Column order:
%   1) PSWF combiner, LMMSE using the channel covariance Corr.CovH
%   2) PSWF combiner, regularized LS (identity prior)
%   3) Random combiner, regularized LS (identity prior)
%   4) Random combiner, AMP recovery in the DFT domain
%   5) AdaptivePSWFestimator in "Bayesian" mode (labelled SBAR in the printout)
%   6) AdaptivePSWFestimator in "BT-PSWF" mode

% The DFT matrix depends only on NRx, so build it once here instead of twice
% per trial. Both the raw and the scaled form are kept so that the products
% below are evaluated in the same order as before (bit-identical results).
dftRaw     = dftmtx(NRx); 
dftUnitary = dftRaw/sqrt(NRx); 

for i_snr = 1:N_SNR
    % Re-seed at every SNR point so that each point sees the same random stream.
    rng(0); 
    P = db2pow(SNR_arr(i_snr)); 
    % NOTE(review): SIMO_params.P still holds the value assigned before this
    % loop; only the local P follows the SNR grid. Confirm that the helpers
    % receiving SIMO_params do not rely on SIMO_params.P.
    inner_NMSE = zeros(N_test, 6);

    for i_test = 1:N_test
        [H, Corr] = getChannel(SIMO_params, "ColoredAngularCorrelation"); 
    
        % Optional debugging hooks:
        % fprintf('Channel Fnorm2 = %.3f / normalized = %.3f NR = %d \n', norm(H, 'fro')^2, norm(H, 'fro')^2/(SIMO_params.RxArray.N), SIMO_params.RxArray.N); 
        % figure(1); 
        % plotSIMOkRegion(SIMO_params, H);

        H = H * sqrt(NRx); % Ensure each Rx antenna receives a unit average signal power. 
        
        % One noise draw per trial, shared by the PSWF and random-combiner observations.
        nvec = sqrt(sigma2)*(randn([Ncombiner, 1]) + 1j*randn([Ncombiner, 1]))/sqrt(2); 

        % [METHOD] PSWF combiner with Statistical CSI
        % NOTE(review): Wpswf looks trial-independent, but it is left inside the
        % loop because getCombiner may consume random numbers; hoisting it would
        % then change every later draw. The 0.1 argument is defined by getCombiner.
        Wpswf = getCombiner(SIMO_params, SIMO_params.RxArray.RegionsRx, Ncombiner, "PSWF", 0.1); 
        y_tilde = Wpswf'*sqrt(P)*H + nvec; 
        % LMMSE estimate C_hy * (C_yy \ y) with channel covariance NRx*Corr.CovH.
        % Equivalent form: (sigma2*inv(NRx*Corr.CovH) + P*(Wpswf*Wpswf'))\(sqrt(P)*Wpswf*y_tilde)
        h_hat = (NRx*Corr.CovH)*sqrt(P)*Wpswf*((P*Wpswf'*(NRx*Corr.CovH)*Wpswf + sigma2*eye(Ncombiner))\y_tilde); 
        inner_NMSE(i_test, 1) = norm(H - h_hat)^2/norm(H)^2; 

        % [METHOD] PSWF combiner w/o Statistical CSI
        % Same observation, but the channel covariance is replaced by the identity.
        h_hat = (sigma2*eye(NRx) + P*(Wpswf*Wpswf'))\(sqrt(P)*Wpswf*y_tilde); 
        inner_NMSE(i_test, 2) = norm(H - h_hat)^2/norm(H)^2; 
    
        % [METHOD] Random Combiner
        Wrand = getCombiner(SIMO_params, SIMO_params.RxArray.RegionsRx, Ncombiner, "Random", 0.1); 
        y_tilde_rand = Wrand'*sqrt(P)*H + nvec; 
        h_hat = (sigma2*eye(NRx) + P*(Wrand*Wrand'))\(sqrt(P)*Wrand*y_tilde_rand); 
        inner_NMSE(i_test, 3) = norm(H - h_hat)^2/norm(H)^2; 
    
        % [METHOD] Compressed Sensing - AMP
        % Sensing matrix = random combiner applied to the scaled DFT dictionary.
        % The observation is divided by sqrt(P) before recovery, and the
        % recovered coefficients are mapped back through the dictionary.
        % The 0.90 argument is defined by camp.
        A = Wrand'*dftRaw/sqrt(NRx); 
        x_hat = camp(y_tilde_rand/sqrt(P), A, 0.90); 
        h_hat = dftUnitary*x_hat; 
        inner_NMSE(i_test, 4) = norm(H - h_hat)^2/norm(H)^2; 

        % [METHOD] Adaptive Bayesian Estimator
        % The helper's return value is stored directly as this trial's error metric.
        inner_NMSE(i_test, 5) = AdaptivePSWFestimator(H, P, nvec, SIMO_params, "Bayesian"); 

        % [METHOD] Adaptive PSWF
        inner_NMSE(i_test, 6) = AdaptivePSWFestimator(H, P, nvec, SIMO_params, "BT-PSWF"); 

    end
    
    % Average over the trials for this SNR point.
    NMSE(i_snr, :) = mean(inner_NMSE, 1);

    % fprintf consumes the elements of the vector argument in order, so the six
    % averaged NMSE values (in dB) fill the six remaining format fields.
    fprintf('SNR = %.3f dB | Avg NMSE = %.3f dB, %.3f dB, %.3f dB, %.3f dB, %.3f dB, %.3f dB\n', SNR_arr(i_snr), pow2db(NMSE(i_snr, :))); 
end

%% Visualization
% Plot the averaged NMSE of every estimator against SNR on a logarithmic
% y-axis and export the figure as a vector PDF.

% Graphics defaults are set on the graphics root, so they also apply to
% figures created later in the same MATLAB session.
set(groot,'DefaultLineMarkerSize',  6);
set(groot,'DefaultTextFontSize',    14);
set(groot,'DefaultAxesFontSize',    12);
set(groot,'DefaultLineLineWidth',   1.4);
set(groot,'defaultfigurecolor',     'w');

% Keep the handle so the export below targets this figure even if another
% figure becomes current in the meantime.
fig = figure('Color', [1, 1, 1]); 

% Curves are drawn in the same order as the legend entries below.
plot(SNR_arr, NMSE(:,3), 'Color','b', 'Marker','v'); hold on; grid on; box on;  
plot(SNR_arr, NMSE(:,4), 'Color', [0 0.6 0.6] , 'Marker','<'); 
plot(SNR_arr, NMSE(:,5), 'Color', [0.8, 0.2, 0.8], 'Marker','o'); 
plot(SNR_arr, NMSE(:,6), 'Color', [1, 0, 0], 'Marker', 'square'); 
plot(SNR_arr, NMSE(:,2), 'Color',[0.7, 0, 0], 'LineStyle','-', 'Marker', 'pentagram'); 
plot(SNR_arr, NMSE(:,1), 'Color','k', 'LineStyle','--'); 

legend({'Random Comb. MMSE', 'Random Comb. AMP', 'SBAR', 'Adaptive PSWF', 'PSWF Comb. w/o Stat CSI', 'PSWF Comb. w/ Stat CSI (Oracle)'}, 'FontSize', 9, 'Location','southwest'); 
xlabel('SNR (dB)'); 
ylabel('NMSE'); 
set(gca, 'yscale', 'log');

% Create the output folder if it is missing, so the export cannot fail on a
% missing directory after the whole simulation has already run.
if ~isfolder('results')
    mkdir('results');
end
exportgraphics(fig, 'results/PerformanceWithSNR.pdf', 'ContentType','vector'); 

%% Utils
function plotSIMOkRegion(SIMO_params, h)
% PLOTSIMOKREGION  Plot the relative intensity of h over the receive dictionary.
%
%   plotSIMOkRegion(SIMO_params, h) opens a new figure and plots
%   |h' * SIMO_params.RxArray.dict|.^2, scaled by its norm, against the first
%   column of SIMO_params.nkGrids. For every entry r of
%   SIMO_params.RxArray.RegionsRx, two dashed vertical lines are added, a red
%   one at r(1) - r(3) and a blue one at r(1) + r(3), and both positions are
%   printed to the console.
%
%   NOTE(review): r(1) and r(3) are presumably a region centre and half-width
%   in normalized wavenumber - confirm against the code that fills RegionsRx.
%
%   The function has no return value.

    rxArray = SIMO_params.RxArray; 

    figure(); 

    % Intensity of h projected on every dictionary entry, scaled by its norm.
    intensity = abs(h'*rxArray.dict).^2.';
    intensity = intensity/norm(intensity); 

    % The region markers are drawn up to the peak of the curve.
    peakVal = max(intensity); 
    plot(SIMO_params.nkGrids(:,1), intensity); hold on; 

    for iRegion = 1:length(rxArray.RegionsRx)
        region = rxArray.RegionsRx{iRegion}; 

        lowerEdge = region(1) - region(3); 
        line([lowerEdge, lowerEdge], [0, peakVal], 'LineStyle', '--', 'Color', 'r');

        upperEdge = region(1) + region(3); 
        line([upperEdge, upperEdge], [0, peakVal], 'LineStyle', '--', 'Color', 'b');

        fprintf('Recv nk region %d: [%.2f ~ %.2f] \n', iRegion, lowerEdge, upperEdge); 
    end

    % Axis cosmetics.
    xlim([-1.1, 1.1]); 
    xlabel('Normalized wavenumber k'); 
    ylabel('Relative Intensity'); 
    grid on; box on; 

end


