% ============================== % 
% This script tests channel estimators (random-pilot and PSWF-pilot MMSE)
% under the colored angular-correlation channel model.
% 
% ============================== % 
clear; close all; clc; 

addpath('algs/'); 
addpath('utils/');

% ---- Carrier / wavelength ------------------------------------------------
fc = 3.5e9;                    % carrier frequency [Hz]
c = physconst('lightspeed');   % speed of light [m/s]
lambda = c/fc;                 % wavelength [m]
k0 = 2*pi/lambda;              % wavenumber [rad/m]

% Correlation factor shared by every angular region below and passed to setup().
Gamma = 0.05; 

% ---- Angular support regions ---------------------------------------------
% Each entry has the format (nkx, nky, nkradius, gammaCorrelationFactor).
% The cell arrays are created directly at their final 1x3 size.
RegionsTx = { [-0.4,  0, 0.15, Gamma], ...
              [ 0.3,  0, 0.15, Gamma], ...
              [ 0.7,  0, 0.15, Gamma] };

RegionsRx = { [-0.85, 0, 0.15, Gamma], ...
              [ 0.1,  0, 0.15, Gamma], ...
              [ 0.52, 0, 0.15, Gamma] };

% ---- Power budget ---------------------------------------------------------
P = 10;        % transmit power (also used as the channel-estimation power)
sigma2 = 1;    % noise variance

%% Get channel and precode. 
% For each aperture length in L_arr, draw N_test channel realizations and
% record (i) the NMSE of two pilot-based MMSE channel estimators and
% (ii) the mutual information of water-filling and equal-power transmission,
% both normalized by log(1+P/sigma2), plus the number of active modes.
rng(0);   % fixed seed -> reproducible Monte-Carlo results

L_arr = [1.9000    1.7000    1.5000    1.3000    1.1000    0.9000    0.7000    0.5000    0.3000    0.1000];   % aperture lengths [m]
N_scan = length(L_arr); 

N_test = 20;   % channel realizations per aperture

maxRate = zeros(N_scan, 1);   % mean normalized MI, SVD + water-filling
lbRate  = zeros(N_scan, 1);   % mean normalized MI, equal power (lower bound)
avgDoF  = zeros(N_scan, 1);   % mean number of modes with p > 0
NMSE    = zeros(N_scan, 2);   % mean NMSE: [RandomPilotMMSE, PSWFPilotMMSE]

% Loop-invariant rate normalization. MIMO_params.P and MIMO_params.sigma2
% are set to P and sigma2 below, so this equals log(1+MIMO_params.P/sigma2).
rateNorm = log(1 + P/sigma2);

% log(det(A)) for a Hermitian positive-definite A via its Cholesky factor:
% log det(A) = 2*sum(log(diag(R))), R = chol(A). Unlike log(det(A)), this
% cannot overflow to Inf / underflow to 0 for large arrays. A is symmetrized
% first to remove round-off asymmetry introduced by the matrix products.
logdetHPD = @(A) 2*sum(log(real(diag(chol((A + A')/2)))));

for i_scan = 1:N_scan
    MIMO_params = setup("SymmetricULAwithFixedK", fc, L_arr(i_scan), Gamma);           % MIMO setup
    MIMO_params.TxArray.RegionsTx = RegionsTx; 
    MIMO_params.RxArray.RegionsRx = RegionsRx; 

    MIMO_params.FreqShift = false; 
    MIMO_params.FreqShiftFactor = 10; 
    MIMO_params.sigma2 = sigma2; 
    MIMO_params.P = P;
    MIMO_params.P_channelEst = P; 
    MIMO_params.Npilot = 300; 
    MIMO_params.debug = false; 

    NTx = MIMO_params.TxArray.N;
    NRx = MIMO_params.RxArray.N;
    
    inner_maxRate = zeros(N_test, 1); 
    inner_lbRate  = zeros(N_test, 1); 
    inner_avgDoF  = zeros(N_test, 1); 
    inner_NMSE    = zeros(N_test, 2); 

    for i_test = 1:N_test
        [H, correlationData] = getChannel(MIMO_params, "ColoredAngularCorrelation"); 
        
        [inner_NMSE(i_test, 1), ~] = estimateChannel(H, correlationData, MIMO_params, "RandomPilotMMSE"); 
        [inner_NMSE(i_test, 2), ~] = estimateChannel(H, correlationData, MIMO_params, "PSWFPilotMMSE");

        % Precode for MIMO channel H: columns of W weighted by powers p.
        [W, ~, p] = waterFilling(H, P, sigma2); 

        % Both matrices have the form I + (PSD term), hence Hermitian positive
        % definite. NOTE(review): this assumes waterFilling returns p >= 0.
        A_wf = eye(NRx) + (1/sigma2)*(H*W*diag(p)*W'*H');
        A_eq = eye(NRx) + (P/(NTx*sigma2))*(H*H');

        inner_maxRate(i_test) = logdetHPD(A_wf)/rateNorm; 
        inner_lbRate(i_test)  = logdetHPD(A_eq)/rateNorm; 
        inner_avgDoF(i_test)  = sum(p > 0); 
    end
    
    maxRate(i_scan) = mean(inner_maxRate); 
    lbRate(i_scan)  = mean(inner_lbRate);
    avgDoF(i_scan)  = mean(inner_avgDoF); 
    NMSE(i_scan, :) = mean(inner_NMSE, 1);   % column-wise mean over realizations
    
    fprintf('Scan %d/%d complete.\n', i_scan, N_scan); 
end


%% Plots
close all; 

% Figure 2: normalized MI versus aperture expressed in half-wavelengths.
% maxRate / lbRate are MI divided by log(1+P/sigma2), hence dimensionless.
nHalfWavelengths = L_arr.'/(lambda/2);
figure(2);
plot(nHalfWavelengths, maxRate); grid on; hold on;
plot(nHalfWavelengths, lbRate); 

xlabel('Number of antennas'); 
ylabel('Normalized information (MI / log(1+P/\sigma^2))'); 
legend({'Optimal SVD+WaterFilling', 'MI with equi-power (lower bound)'}); 

% Figure 3: normalized MI and empirical DoF versus Tx aperture, together
% with the theoretical approximation.
figure(3);
plot(L_arr.', maxRate, 'Marker','square', 'color', [0 0 1]); grid on; hold on;
plot(L_arr.', lbRate, 'Marker','v', 'color', [1, 0, 0]); 
plot(L_arr.', avgDoF, 'Marker','<', 'color', [0, 0.7, 0.7]); 

% Total angular width (sum of 2*nkradius over all regions) on each side,
% combined as a geometric mean.
OmegaTx = sum(cellfun(@(r) 2*r(3), RegionsTx));
OmegaRx = sum(cellfun(@(r) 2*r(3), RegionsRx));
Omega = sqrt(OmegaRx*OmegaTx); 

% Theoretical DoF grows with aperture/lambda until it saturates at 1/Gamma;
% the log term is a SNR-dependent correction. f is loop-invariant.
f = log(P/sigma2)/(2*pi^2); 
theoreticalDoF = Omega*min(L_arr(:)/lambda, 1/Gamma);                     % N_scan x 1
theoreticalRelativeMI = theoreticalDoF + 3*log(2*pi*theoreticalDoF)*f;   % N_scan x 1
plot(L_arr.', theoreticalRelativeMI, 'Linestyle', ':', 'Color', 'k'); 

xlabel('Tx Aperture [m]'); 
ylabel('Effective DoF'); 
legend({'N-MI: Optimal SVD+WaterFilling', ...
    'N-MI: Equi-power (lower bound)', ...
    'DoF: Empirical', ...
    'DoF: Theoretical Approx.'}); 

% Figure 4: channel-estimation NMSE averaged over N_test realizations.
% NOTE(review): plotted exactly as returned by estimateChannel; confirm
% whether that value is linear or already in dB before changing the axis scale.
figure(4);
plot(L_arr.', NMSE(:,1), 'Marker','o'); grid on; hold on;
plot(L_arr.', NMSE(:,2), 'Marker','x');
xlabel('Tx Aperture [m]');
ylabel('NMSE');
legend({'Random pilots + MMSE', 'PSWF pilots + MMSE'});
