% ============================== % 
% This file studies the DoF/Capacity of the proposed colored fading channel
% model 
%   
%    Output: Relative capacity gap vs. SNR and aperture size. 
%    Regime: XL-MIMO with (lambda/2)-spacing. 
% 
% 
% ============================== % 
% ---- Start from a clean session (variables, cached functions, figures, console). ----
clear all; close all; clc;

% ---- Project folders providing the helper routines used further down. ----
addpath('algs/');
addpath('utils/');

% ---- Carrier and derived wave quantities. ----
fc      = 3.5e9;                        % carrier frequency
c       = physconst('lightspeed');      % speed of light
lambda  = c/fc;                         % wavelength
k0      = 2*pi/lambda;                  % free-space wavenumber

% ---- Angular (wavenumber-domain) scattering regions. ----
% Every region is a row vector laid out as
%   (nkx, nky, nkradius, gammaCorrelationFactor).
Gamma   = 0.05;

RegionsTx       = {[-0.4, 0, 0.45, Gamma]};
% RegionsTx{2}    = [0.3, 0, 0.15, Gamma];
% RegionsTx{3}    = [0.7, 0, 0.15, Gamma];

RegionsRx       = {[-0.3, 0, 0.45, Gamma]};
% RegionsRx{2}    = [0.1, 0, 0.15, Gamma];
% RegionsRx{3}    = [0.52, 0, 0.15, Gamma];

% Bundle the region description into the single struct handed to setup().
reg = struct( ...
    'RegionsTx', {RegionsTx}, ...
    'RegionsRx', {RegionsRx}, ...
    'Gamma',     Gamma);

% Transmit power is not fixed here; it is derived per scan point from the SNR grid.
% P       = input('Enter transmit power P:\n');
% P = 10;
sigma2  = 1;                            % noise power


%% Get channel and precode. 
rng(0);                                 % fixed seed so the Monte-Carlo runs are repeatable

% Scan grids: aperture values and SNR values (dB).
SNR_arr = 0:2:20;
L_arr   = 0.1:0.2:5;

N_SNR   = numel(SNR_arr);
N_L     = numel(L_arr);
N_scan  = N_L*N_SNR;                    % one scan point per (aperture, SNR) pair
N_test  = 100;                          % channel realizations per scan point

% Per-scan-point statistics (mean / standard deviation over the N_test trials).
[maxRate, maxRateStd, ...
 lbRate,  lbRateStd,  ...
 avgDoF,  avgDoFStd,  ...
 theoreticalDoF] = deal(zeros(N_scan, 1));

% Column 1 receives ECa, column 2 receives ECpswf (both returned by getPSWFBound).
theoreticalRelativeMI = zeros(N_scan, 2);

% PSWF Bound
% Total angular width on each side: sum of the diameters (2 * nkradius) of all regions.
OmegaTx = sum(cellfun(@(region) 2*region(3), RegionsTx));
OmegaRx = sum(cellfun(@(region) 2*region(3), RegionsRx));

% Geometric mean of the two widths.
Omega = sqrt(OmegaRx*OmegaTx);

% Banner and column header of the progress table.
bannerRows = { ...
    'Simulation with half-wavelength antenna spacing', ...
    '+============+===============+=======================================================+===================+', ...
    '|            |               |               Normalized  Rate (bps/Hz)               |                   |', ...
    '|  Progress  |   (Nr, Nt)    +-------------------------------------------------------+  Elapsed Time (s) |', ...
    '|            |               | max Rate | lower bound |  average DoF  |  PSWF bound  |                   |', ...
    '+------------+---------------+----------+-------------+---------------+--------------+-------------------+'};
fprintf('%s\n', bannerRows{:});

% Monte-Carlo scan over every (aperture, SNR) pair.
% For each scan point the loop
%   1) builds the array setup and evaluates the PSWF-based bound,
%   2) draws N_test channel realizations,
%   3) records the water-filling rate, the equal-power rate and the number of
%      active streams, each rate normalized by log(1 + P/sigma2),
%   4) prints one row of the progress table.

% Log-determinant of a Hermitian positive-definite matrix from its eigenvalues.
% det() of a large matrix can overflow and may carry a spurious imaginary part;
% summing log-eigenvalues of the symmetrized matrix avoids both problems.
logdetHerm = @(A) sum(log(eig((A + A')/2)));

for i_scan = 1:N_scan
    % Linear scan index -> grid indices; the aperture index varies fastest.
    L_idx   = mod(i_scan-1, N_L)+1; 
    SNR_idx = floor((i_scan-1)/N_L)+1; 

    L   = L_arr(L_idx); 
    SNR = SNR_arr(SNR_idx); 
    P   = db2pow(SNR);                  % SNR in dB -> linear transmit power

    MIMO_params = setup("SymmetricULAwithFixedK", fc, L, reg);           % MIMO setup
    MIMO_params.memorizeCorrelation     = true; 
    MIMO_params.FreqShift               = false; 
    MIMO_params.FreqShiftFactor         = 40; 
    MIMO_params.sigma2                  = sigma2; 
    MIMO_params.P                       = P;
    MIMO_params.debug                   = false; 

    % NOTE(review): only the Tx angular width is passed here, while the script
    % also computes Omega = sqrt(OmegaRx*OmegaTx) and never uses it. Both are
    % equal for the current regions; confirm which one is intended.
    [ECpswf, ECa, aDoF, ~]              = getPSWFBound(OmegaTx, L/lambda, Gamma, P/sigma2); 
    theoreticalDoF(i_scan)              = aDoF; 
    theoreticalRelativeMI(i_scan,1)     = ECa; 
    theoreticalRelativeMI(i_scan,2)     = ECpswf; 

    % Quantities that do not change across the trials of this scan point.
    Nr       = MIMO_params.RxArray.N; 
    Nt       = MIMO_params.TxArray.N; 
    I_Nr     = eye(Nr); 
    rateNorm = log(1 + P/sigma2);       % normalization applied to both rates

    inner_maxRate       = zeros(N_test, 1); 
    inner_lbRate        = zeros(N_test, 1); 
    inner_avgDoF        = zeros(N_test, 1); 

    % Use a timer handle so that a bare tic inside any helper cannot reset
    % the measurement of this scan point.
    tStart = tic; 
    for i_test = 1:N_test
        H = getChannel(MIMO_params, "ColoredAngularCorrelation"); 

        H = H * (1/2);      % Energy normalization due to lambda/2 spacing. 
        % [NMSE, ~] = estimateChannel(H, MIMO_params, "PSWF"); 

        % Precode for MIMO channel H. 
        [W, ~, p] = waterFilling(H, P, sigma2); 

        % Rate with the water-filling precoder W and power allocation p.
        inner_maxRate(i_test) = logdetHerm(I_Nr + (1/sigma2)*(H*W*diag(p)*W'*H')) / rateNorm; 
        % Rate with the power P split equally over the Nt transmit antennas.
        inner_lbRate(i_test)  = logdetHerm(I_Nr + (P/(Nt*sigma2))*(H*H')) / rateNorm; 
        % Number of streams that received non-zero power.
        inner_avgDoF(i_test)  = sum(p>0); 
    end
    timeElapsed = toc(tStart); 

    maxRate(i_scan)     = mean(inner_maxRate); 
    lbRate(i_scan)      = mean(inner_lbRate);
    avgDoF(i_scan)      = mean(inner_avgDoF); 

    maxRateStd(i_scan)  = std(inner_maxRate); 
    lbRateStd(i_scan)   = std(inner_lbRate); 
    avgDoFStd(i_scan)   = std(inner_avgDoF); 

    fprintf('|  %3d/%3d   |  (%3d, %3d)   |  % 6.2f  |   % 6.2f    |    % 6.2f     |    %6.2f    |      % 7.3f      |\n', ...
        i_scan, N_scan,     ...
        Nr,  Nt, ...
        maxRate(i_scan),    ...
        lbRate(i_scan),     ...
        avgDoF(i_scan),     ...
        theoreticalRelativeMI(i_scan, 2), ...
        timeElapsed); 
end
fprintf('+============+===============+=======================================================+===================+\n');
fprintf('Simulation successfully completed.\n'); 


%% Plots
close all; 

saveFigs = true; 

% Session-wide graphics defaults (set on the root object, so they persist
% for figures created after this script finishes).
set(0,'DefaultLineMarkerSize',  6);
set(0,'DefaultTextFontSize',    14);
set(0,'DefaultAxesFontSize',    12);
set(0,'DefaultLineLineWidth',   1.4);
set(0,'defaultfigurecolor',     'w');


% Plot the capacity difference as a 2D function of aperture and SNR. 
fig1 = figure(1);
subplot(2, 1, 1); 
[X, Y] = meshgrid(L_arr, SNR_arr); 

% Scan results are stored as column vectors with the aperture index varying
% fastest, i.e. scan index = (SNR_idx-1)*N_L + L_idx. Reshaping to N_L-by-N_SNR
% and transposing therefore yields the (SNR, aperture) layout that matches the
% meshgrid above.
relGap  = (theoreticalRelativeMI(:, 2) - maxRate) ./ maxRate;   % (bound - simulated) / simulated
relStd  = maxRateStd ./ maxRate;                                % std relative to the mean rate
Z       = reshape(relGap, N_L, N_SNR).'; 
Z_std   = reshape(relStd, N_L, N_SNR).'; 

surf(X, Y, Z); shading interp;  
colormap('jet'); 
colorbar; 
% plot(L_arr.', maxRate, 'Marker','square', 'color', [0 0 1]); grid on; hold on; box on; 
% plot(L_arr.', lbRate, 'Marker','v', 'color', [1, 0, 0]); 
% plot(L_arr.', theoreticalRelativeMI(:,2), 'Linestyle', '--', 'Color', 'k'); 
xlabel('Tx Aperture (m)'); 
ylabel('SNR (dB)'); 
title('Relative capacity gap'); 
view(0, 90);                            % look straight down: surface shown as a 2D map

subplot(2, 1, 2); 
surf(X, Y, Z_std); shading interp; 
colormap('jet'); colorbar; 
xlabel('Tx Aperture (m)'); 
ylabel('SNR (dB)'); 
title('Relative capacity uncertainty');  
view(0, 90); 

if saveFigs
    outFile = 'results/Response/XLMIMO_capacity_gap.pdf'; 
    % Make sure the destination folder exists; otherwise the export would
    % fail only after the whole simulation has already been run.
    outDir = fileparts(outFile); 
    if ~isempty(outDir) && ~isfolder(outDir)
        mkdir(outDir); 
    end
    exportgraphics(fig1, outFile, 'ContentType','vector'); 
end


% Plot with errorbars. 
% figure(2); 
% plot(L_arr.', theoreticalRelativeMI(:,2), 'Linestyle', '--', 'Color', 'k'); grid on; hold on; box on; 
% errorbar(L_arr.', maxRate, maxRateStd, 'Marker','square', 'color', [0 0 1], 'LineWidth', 1); 
% errorbar(L_arr.', lbRate, lbRateStd, 'Marker','v', 'color', [1, 0, 0], 'LineWidth', 1); 
% xlabel('Tx Aperture [m]'); 
% ylabel('Normalized Capacity'); 
% legend({'Ergodic Capacity (PSWF bound)', 'SVD+WaterFilling', 'Equi-power'}, 'Location','best'); 
% if saveFigs
%     exportgraphics(gcf, 'results/XLMIMOregime/XLMIMO_capacityWithErrorBars.pdf', 'ContentType','vector'); 
% end


% Plot the DoF. 
% figure(3); 
% plot(L_arr.', theoreticalRelativeMI(:,2), 'LineStyle','--', 'Color','k'); hold on; grid on; box on; 
% plot(L_arr.', avgDoF, 'Marker','<', 'color', [0, 0.7, 0.7]);  
% xlabel('Tx Aperture [m]'); 
% ylabel('Effective DoF'); 
% legend({'PSWF-DoF Approx.', 'Empirical DoF'}, 'Location','best'); 
% 
% if saveFigs
%     exportgraphics(gcf, 'results/XLMIMOregime/XLMIMO_DoF.pdf', 'ContentType','vector'); 
%     save('results/XLMIMOregime/data.mat'); 
%     fprintf('Files saved at results/XLMIMOregime/.\n'); 
% end

fprintf('End of this script.\n'); 

%% Generate channel H using LoS near-field assumption

% H = zeros(UPA.Nx*UPA.Ny); 
% 
% for idx = 1:UPA.Nx*UPA.Ny
%     ix = mod(idx-1, UPA.Ny) + 1;
%     iy = (idx - ix)/UPA.Ny + 1;
%     sep = upa(1).coord(ix, iy, :) - upa(2).coord;
% 
%     h = exp(-1i*k0*sqrt(sum(sep.^2, 3))); 
%     H(:, idx) = h(:); 
% end
% 
% H = H/norm(H, 'fro');
% [U, S, V] = svd(H); 
% 
% s = diag(S);
% smax = max(s); 
% 
% 
% 
% RD = 2*(sqrt(upa(1).Lx^2+upa(1).Ly^2))/lambda;
% 
% csq = 4*pi/(lambda*dsep)^2*(upa(1).Lx*upa(1).Ly)*(upa(2).Lx*upa(2).Ly); 
% NDOF_PSWF = csq/4;


%% 
% figure(1); 
% imagesc(real(reshape(V(:,4), [UPA.Nx, UPA.Ny]))); axis equal; 
% colorbar; 
% 
% figure(2); 
% N_lambda = floor(NDOF_PSWF); 
% plot((1:N_lambda).', mag2db(s(1:N_lambda)/smax)); grid on; box on; 
