% ============================== % 
% This script studies channel estimation for a colored (spatially 
% correlated) fading channel model: it compares the NMSE of several 
% channel estimators over a range of SNRs. 
% 
% ============================== % 
clear all; close all; clc;

% Put the project's algs/ and utils/ folders on the MATLAB path (utils/ is added last, so it ends up first). 
addpath('algs/'); 
addpath('utils/');

% Carrier frequency and derived free-space quantities. 
fc          = 3.5e9;                        % carrier frequency (Hz)
c           = physconst('lightspeed');      % speed of light (m/s)
lambda      = c/fc;                         % wavelength (m)
k0          = 2*pi/lambda;                  % wavenumber (rad/m)


%% Simulate channel estimators. 

rng(0);                                     % fix the RNG state before the array setup below

% ---- Simulation parameters ----
L                       = 1;                % alternative: L = 3
N_test                  = 100;              % Monte-Carlo trials per SNR point
P                       = 100;              % initial power; reassigned for every SNR point in the main loop
sigma2                  = 1;                % noise variance
Ncombiner               = 120;              % number of combiner vectors (measurements)
channelGenerationMethod = "CDL"; 

% ---- Region ("bandwidth") setup: one region for the Tx array and one for the Rx array ----
% NOTE(review): plotSIMOkRegion reads entry 1 as the region center and entry 3 as its half-width.
reg             = struct('Gamma', 0.05); 
RegionsTx       = {[0, 0, 0.15, reg.Gamma]}; 
RegionsRx       = {[0, 0, 0.15, reg.Gamma]}; 
reg.RegionsTx   = RegionsTx; 
reg.RegionsRx   = RegionsRx; 

% ---- Array / system model ----
% MIMO = setup("SIMOULAwithFixedK", fc, L, reg);     
MIMO                    = setup("SymmetricULAwithFixedK", fc, L, reg);  
MIMO.P                  = P; 
MIMO.sigma2             = sigma2; 
MIMO.normalizeChannel   = true;             % Normalize the average power of each row of the channel matrix to 1. 
MIMO.forceGetCovariance = true;
MIMO.Ncombiner          = Ncombiner; 
NRx                     = MIMO.RxArray.N; 

% ---- SNR grid and result buffers (one column per estimator) ----
SNR_arr                     = (-10:2:20).'; 
N_SNR                       = numel(SNR_arr); 
[NMSE, StdNMSE, StdNMSEdB]  = deal(zeros(N_SNR, 7)); 

% ---- Console report header ----
fprintf('Simulation with Ncombiner = %d and Nantenna = %d\n', Ncombiner, NRx); 
fprintf('Channel Generation Method = %s\n', channelGenerationMethod); 
fprintf('+==========+========================================================================================================+\n');
fprintf('|          |                                                 NMSE(dB)                                               |\n');
fprintf('|  SNR(dB) +--------------------------------------------------------------------------------------------------------+\n');
fprintf('|          | PSWF w/  Stat | PSWF w/o Stat |  RandComb  |  AMP  |   SBAR   | BT-PSWF | BWest-PSWF | Time Elapsed (s)|\n');
fprintf('+----------+---------------+---------------+------------+-------+----------+---------+------------+-----------------+\n');

% ---- Monte-Carlo evaluation of all estimators over the SNR grid ----
% Quantities that do not change inside the loops are built once here.
N_methods   = size(NMSE, 2);                      % one NMSE column per estimator
F_Rx        = dftmtx(NRx)/sqrt(NRx);              % unitary DFT dictionary of the Rx array (used by AMP)
nmse        = @(Htrue, Hest) norm(Htrue(:) - Hest(:))^2/norm(Htrue(:))^2;   % normalized squared error

% Methods 5 (SBAR) and 6 (BT-PSWF) are only run in SIMO mode. Columns of methods that
% are not run are reported as NaN instead of an all-zero NMSE, which would otherwise be
% printed as -Inf dB (i.e. look like a perfect estimate).
methodRun   = true(1, N_methods);
if ~MIMO.isSIMO
    methodRun([5, 6]) = false;
    NTx     = MIMO.TxArray.N;
    Nh      = NRx*NTx;                            % length of h = vec(H)
    % Separable DFT dictionary: S*vec(X) = vec(F_Rx*X*F_Tx'), F_Tx = dftmtx(NTx)/sqrt(NTx).
    S       = kron(conj(dftmtx(NTx)/sqrt(NTx)), F_Rx);
end

for i_snr = 1:N_SNR
    % SNR is taken as P/sigma2; scaling by sigma2 keeps the SNR axis correct if sigma2 ~= 1
    % (for sigma2 = 1 this is exactly db2pow(SNR)).
    P               = db2pow(SNR_arr(i_snr))*sigma2;
    MIMO.P          = P;

    inner_NMSE      = zeros(N_test, N_methods);   % per-trial NMSE, one column per method

    tic;
    for i_test = 1:N_test
        rng(i_test);    % same seed for a given trial index at every SNR point

        % [H, Corr] = getChannel(MIMO, "ColoredAngularCorrelation");
        [H, Corr] = getChannel(MIMO, channelGenerationMethod);

        if MIMO.isSIMO
            H       = H * sqrt(NRx);    % Ensure each Rx antenna receives a unit average signal power.
            nvec    = sqrt(sigma2)*(randn([Ncombiner, 1]) + 1j*randn([Ncombiner, 1]))/sqrt(2);
            % Prior covariance matching the sqrt(NRx) rescaling of H
            % (assumes Corr.CovH is the covariance of the unscaled H — TODO confirm).
            R       = NRx*Corr.CovH;

            % [METHOD 1] PSWF combiner, LMMSE with the true covariance (statistical CSI)
            Wpswf   = getCombiner(MIMO, MIMO.RxArray.RegionsRx, Ncombiner, "PSWF", 0.1);
            y_tilde = Wpswf'*sqrt(P)*H + nvec;
            % h_hat = (sigma2*inv(NRx*Corr.CovH) + P*(Wpswf*Wpswf'))\(sqrt(P)*Wpswf*y_tilde);
            h_hat   = R*sqrt(P)*Wpswf*((P*Wpswf'*R*Wpswf + sigma2*eye(Ncombiner))\y_tilde);
            inner_NMSE(i_test, 1) = nmse(H, h_hat);

            % [METHOD 2] PSWF combiner, LMMSE with an identity prior (no statistical CSI)
            h_hat   = (sigma2*eye(NRx) + P*(Wpswf*Wpswf'))\(sqrt(P)*Wpswf*y_tilde);
            inner_NMSE(i_test, 2) = nmse(H, h_hat);

            % [METHOD 3] Random combiner, LMMSE with an identity prior
            Wrand           = getCombiner(MIMO, MIMO.RxArray.RegionsRx, Ncombiner, "Random", 0.1);
            y_tilde_rand    = Wrand'*sqrt(P)*H + nvec;
            h_hat           = (sigma2*eye(NRx) + P*(Wrand*Wrand'))\(sqrt(P)*Wrand*y_tilde_rand);
            inner_NMSE(i_test, 3) = nmse(H, h_hat);

            % [METHOD 4] Compressed sensing (AMP) on the random-combiner measurements, DFT sparsity basis
            x_hat   = camp(y_tilde_rand/sqrt(P), Wrand'*F_Rx, 0.90);
            h_hat   = F_Rx*x_hat;
            inner_NMSE(i_test, 4) = nmse(H, h_hat);

            % [METHOD 5] Adaptive Bayesian estimator
            inner_NMSE(i_test, 5) = AdaptivePSWFestimator(H, P, nvec, MIMO, "Bayesian");

            % [METHOD 6] Adaptive BeamTraining-PSWF (BT-PSWF)
            inner_NMSE(i_test, 6) = AdaptivePSWFestimator(H, P, nvec, MIMO, "BT-PSWF");

            % [METHOD 7] Bandwidth-estimating PSWF
            inner_NMSE(i_test, 7) = AdaptivePSWFestimator(H, P, nvec, MIMO, "BWest-PSWF");
        else
            H       = H * sqrt(Nh);
            nvec    = sqrt(sigma2)*(randn([Ncombiner, 1]) + 1j*randn([Ncombiner, 1]))/sqrt(2);
            % Prior covariance matching the sqrt(Nh) rescaling of vec(H)
            % (assumes Corr.CovH is the covariance of the unscaled vec(H) — TODO confirm).
            R       = Nh*Corr.CovH;
            [W, V]  = getMIMOCombinerPrecoder(MIMO, MIMO.RxArray.RegionsRx, MIMO.TxArray.RegionsTx, Ncombiner, "PSWF", 0.1);

            % Model: y = sqrt(P)*A*h + noise, h = vec(H); row idx of A is kron(v.', w')
            % because w'*H*v = kron(v.', w')*vec(H).
            y       = zeros(Ncombiner, 1);
            A       = zeros(Ncombiner, Nh);
            for idx = 1:Ncombiner
                y(idx)      = sqrt(P)*(W(:,idx)')*H*V(:,idx) + nvec(idx);
                A(idx, :)   = kron(V(:,idx).', W(:,idx)');
            end

            % [Method 1] PSWF combiner/precoder, LMMSE with the true covariance
            h_hat   = sqrt(P)*(R*A')*((P*A*R*A' + sigma2*eye(Ncombiner))\y);
            inner_NMSE(i_test, 1) = nmse(H, h_hat);

            % [Method 2] PSWF combiner/precoder, LMMSE with an identity prior
            h_hat   = (sigma2*eye(Nh) + P*(A'*A))\(sqrt(P)*A'*y);
            inner_NMSE(i_test, 2) = nmse(H, h_hat);

            % [Method 3] Random combiner/precoder, LMMSE with an identity prior
            [Wrand, Vrand]  = getMIMOCombinerPrecoder(MIMO, MIMO.RxArray.RegionsRx, MIMO.TxArray.RegionsTx, Ncombiner, "Random", 0.1);
            Arand           = zeros(Ncombiner, Nh);
            yrand           = zeros(Ncombiner, 1);
            for idx = 1:Ncombiner
                yrand(idx)      = sqrt(P)*(Wrand(:,idx)')*H*Vrand(:,idx) + nvec(idx);
                Arand(idx, :)   = kron(Vrand(:,idx).', Wrand(:,idx)');
            end
            h_hat   = (sigma2*eye(Nh) + P*(Arand'*Arand))\(sqrt(P)*Arand'*yrand);
            inner_NMSE(i_test, 3) = nmse(H, h_hat);

            % [Method 4] AMP on the random measurements with the separable DFT basis S
            x_hat   = camp(yrand/sqrt(P), Arand*S, 0.90);
            h_hat   = S*x_hat;
            inner_NMSE(i_test, 4) = nmse(H, h_hat);

            % [Method 7] Bandwidth-estimating PSWF
            inner_NMSE(i_test, 7) = AdaptivePSWFestimator(H, P, nvec, MIMO, "MIMO-BWest-PSWF");

            % Methods 5 and 6 are not implemented for MIMO (see methodRun).
        end
    end
    timeElapsed = toc();

    inner_NMSE(:, ~methodRun)   = NaN;            % methods not run in this mode
    NMSE(i_snr, :)              = mean(inner_NMSE, 1);
    StdNMSE(i_snr, :)           = std(inner_NMSE, 0, 1);
    % pow2db is applied only to methods that were run (non-negative NMSE ratios).
    StdNMSEdB(i_snr, :)         = NaN;
    StdNMSEdB(i_snr, methodRun) = std(pow2db(inner_NMSE(:, methodRun)), 0, 1);
    NMSEdB                      = nan(1, N_methods);
    NMSEdB(methodRun)           = pow2db(NMSE(i_snr, methodRun));

    fprintf('|  %6.2f  |     % 6.2f    |    % 6.2f     |   % 6.2f   |% 6.2f |  % 6.2f  | % 6.2f  |   %6.2f   |      %6.3f     |\n', ...
        SNR_arr(i_snr), NMSEdB, timeElapsed);
end
fprintf('+----------+---------------+---------------+------------+-------+----------+---------+------------+-----------------+\n');
fprintf('Simulation complete...\n');


%% Visualization

set(0,'DefaultLineMarkerSize',  6);
set(0,'DefaultTextFontSize',    14);
set(0,'DefaultAxesFontSize',    12);
set(0,'DefaultLineLineWidth',   1.4);
set(0,'defaultfigurecolor',     'w');

figure('Renderer', 'painters', 'Position', [300 100 700 550], 'Color', [1 1 1]);  % format: (x, y, width, height) in pixels. 

saveFigs = false;
saveFiles = false; 
errorBar = true;    % true: NMSE in dB with +/- std(NMSE in dB) bars; false: linear NMSE on a log axis

% Methods to draw, in drawing (and therefore legend) order. Methods 5 and 6 are SIMO-only.
if MIMO.isSIMO
    methodSel = [3, 4, 5, 6, 7, 2, 1]; 
else
    methodSel = [3, 4, 7, 2, 1]; 
end

% Line style of each method, indexed by its NMSE column.
methodStyles = { ...
    {'Color', 'k',             'LineStyle', '--', 'Marker', 'x'}, ...                          % 1
    {'Color', [0.7, 0, 0],     'LineStyle', '-',  'Marker', 'pentagram', 'MarkerSize', 9}, ... % 2
    {'Color', 'b',             'Marker', 'v',       'LineWidth', 1}, ...                        % 3
    {'Color', [0 0.6 0.6],     'Marker', '<',       'LineWidth', 1}, ...                        % 4
    {'Color', [0.8, 0.2, 0.8], 'Marker', 'o',       'LineWidth', 1}, ...                        % 5
    {'Color', [1, 0.3, 0],     'Marker', 'square',  'LineWidth', 1}, ...                        % 6
    {'Color', [1, 0, 0],       'Marker', 'diamond', 'LineWidth', 1}, ...                        % 7
    };

% Draw exactly the selected methods in methodSel order, so the legend entries built from
% methodSel below line up with the plotted curves in both plotting modes.
for m = methodSel
    if errorBar
        errorbar(SNR_arr, pow2db(NMSE(:,m)), StdNMSEdB(:,m), methodStyles{m}{:});
    else
        plot(SNR_arr, NMSE(:,m), methodStyles{m}{:});
    end
    hold on;
end
grid on; box on;
xlabel('SNR (dB)'); 

if errorBar
    ylabel('NMSE (dB)'); 
    if MIMO.isSIMO
        ylim([-40, 9]);
    else
        ylim([-22, 6]); 
    end
else
    set(gca, 'yscale', 'log');
    ylabel('NMSE'); 
end

% line([36, 36], [1e-3, 1e0], 'linestyle', ':', 'Color', 'k'); 
methodNames = { 'PSWF w/ Stat CSI MMSE (Oracle)', ...
                'PSWF w/o Stat CSI MMSE', ...
                'Random Comb. MMSE', ...
                'Random Comb. AMP', ...
                'SBAR', ...
                'Adaptive PSWF', ...
                'BWest PSWF', ...
                }; 
legend(methodNames(methodSel), "FontSize", 10, 'Location','northeast'); 

% Outputs go to results/Estimators/, created on demand (exportgraphics/save would error
% if the folder were missing). The two flags act independently.
if (saveFigs || saveFiles) && ~exist('results/Estimators', 'dir')
    mkdir('results/Estimators');
end

if saveFigs
    if MIMO.isSIMO
        exportgraphics(gcf, 'results/Estimators/SIMO_CE_Performance_wSNR.pdf', 'ContentType','vector');
    else
        exportgraphics(gcf, 'results/Estimators/MIMO_CE_Performance_wSNR.pdf', 'ContentType','vector');
    end
    fprintf('Figs saved.\n'); 
end

if saveFiles
    save('results/Estimators/data.mat');    % saves the whole workspace
    fprintf('Files saved.\n'); 
end

%% Utils
function plotSIMOkRegion(SIMO_params, h)
% plotSIMOkRegion  Plot the normalized |h' * dict|^2 profile of h and mark the Rx regions.
%   SIMO_params.RxArray.dict      : dictionary whose columns are correlated with h
%   SIMO_params.nkGrids(:,1)      : x-axis values (normalized wavenumber) for the profile
%   SIMO_params.RxArray.RegionsRx : cell array; region r is drawn as
%                                   [R{r}(1) - R{r}(3), R{r}(1) + R{r}(3)]
%   h                             : channel vector (presumably a column — TODO confirm)
%   Opens a new figure and prints each region's bounds to the console.
    figure(); 

    % Squared correlation with every dictionary column, as a column vector of unit 2-norm.
    profile = (abs(h' * SIMO_params.RxArray.dict).^2).';
    profile = profile / norm(profile); 
    peak    = max(profile); 

    plot(SIMO_params.nkGrids(:,1), profile);
    hold on; 

    regions = SIMO_params.RxArray.RegionsRx;
    for r = 1:length(regions)
        center = regions{r}(1);
        halfW  = regions{r}(3);
        lo     = center - halfW;
        hi     = center + halfW;

        line([lo, lo], [0, peak], 'LineStyle', '--', 'Color', 'r');   % lower edge
        line([hi, hi], [0, peak], 'LineStyle', '--', 'Color', 'b');   % upper edge

        fprintf('Recv nk region %d: [%.2f ~ %.2f] \n', r, lo, hi); 
    end

    xlim([-1.1, 1.1]); 
    xlabel('Normalized wavenumber k'); 
    ylabel('Relative Intensity'); 
    grid on; box on; 

end


