% Script setup: reset the workspace and console, then put the solver folder
% (presumably provides OMP / VAMP used below — confirm) on the MATLAB path.
% `clear` (not `clear all`) is enough here and avoids reloading every function.
clear
clc

% Relative folder names without a trailing separator work on Windows, Linux and macOS.
% addpath("Kernel")
addpath("CS solvers")

%%% System parameters
f = 3.5e9;              % carrier frequency [Hz]
c = 3e8;                % propagation speed [m/s]
lambda = c/f;           % carrier wavelength [m]
N1 = 16;                % number of horizontal antennas
N2 = 8;                 % number of vertical antennas
N = N1 * N2;            % total number of antennas
M = N;                  % number of pilots for LS (the LS estimator below requires M == N)
Q = 64;                 % number of pilots for the GPR/MMSE, OMP and VAMP estimators
P = Q;                  % total power allocated to all Q pilots
SNR_dB = 0;  SNR_linear = 10.^(SNR_dB/10.);
sigma2 = 1/SNR_linear;  % noise variance corresponding to SNR_dB
N_iter = 100;           % Monte-Carlo trials per antenna spacing
s = 10;                 % sparsity level passed to OMP

% lambda_over_d_list = linspace(2,16,32);
% Antenna spacings to sweep, expressed as fractions of the wavelength.
d_list = lambda * linspace(1/16,1.5,16);

% Oversampled 2D-DFT dictionary used by the compressed-sensing solvers.
% Each axis is oversampled by 2: D1 = 2*N1 horizontal and D2 = 2*N2 vertical
% grid points, for D = D1*D2 dictionary atoms in total.
D1 = 2*N1;   % horizontal dictionary size
D2 = 2*N2;   % vertical dictionary size
D  = D1 * D2;

% Antenna index grids centred on the array midpoint (column vectors).
row1 = transpose(-(N1 - 1)/2 : (N1 - 1)/2);
row2 = transpose(-(N2 - 1)/2 : (N2 - 1)/2);

% Sampling grids on (-1, 1] (row vectors); presumably the normalized spatial
% frequency of each atom — confirm against SV_channel's convention.
col1 = -1 + 2/D1 : 2/D1 : 1;
col2 = -1 + 2/D2 : 2/D2 : 1;

% Per-axis dictionaries (N1 x D1 and N2 x D2) with unit-norm columns.
F1 = exp(1j*pi * row1 * col1) / sqrt(N1);
F2 = exp(1j*pi * row2 * col2) / sqrt(N2);

% Full N x D dictionary. The Kronecker order maps antenna (n1, n2) to row
% (n1-1)*N2 + n2 — assumes SV_channel stacks h the same way (TODO confirm).
DFT2 = kron(F1, F2);



% Number of antenna spacings in the sweep.
d_len = length(d_list);

% One averaged-NMSE entry per spacing for every estimator.
% NMSE_LASSO is kept for the (currently commented-out) LASSO baseline.
[NMSE_LS, ...
 NMSE_OMP, NMSE_AMP, NMSE_LASSO, ...
 NMSE_rand, NMSE_TopQ, NMSE_QWF, ...
 NMSE_MM, NMSE_WF] = deal(zeros(1, d_len));

% Wall-clock reference for the progress messages printed during the sweep.
t0 = clock;

parfor i_d = 1:d_len
    % ---- Per-spacing setup (one parfor iteration per antenna spacing) ----
    d = d_list(i_d);
    % Only the third output of Kernal_generate_func is used; it serves as the
    % GPR prior covariance (kernel) for every MMSE estimator below.
    [~, ~, Kernel_SV] = Kernal_generate_func(f, N1, N2, d);
    kernel_conv = Kernel_SV;

    % Pilot (observation) matrices designed by GPR() for each strategy.
    W_WF  = GPR(kernel_conv, Q, N, P, sqrt(sigma2), 'WF');
    W_QWF = GPR(kernel_conv, Q, N, P, sqrt(sigma2), 'QWF');
    W_MM  = GPR(kernel_conv, Q, N, P, sqrt(sigma2), 'MM');
    % Scaled DFT pilots for LS; this is M x N, so W_opt'*h below needs M == N.
    W_opt = exp(1j*2*pi* (0:1:M-1)' * (0:1:N-1)/M)/sqrt(N);

    % Top-Q eigenvectors of the kernel for the MMSE + TopQ baseline.
    % eig() only guarantees ascending eigenvalues for exactly Hermitian input,
    % so symmetrize (a no-op for a Hermitian kernel) and sort explicitly.
    [V, Lambda] = eig((kernel_conv + kernel_conv')/2);
    [~, order] = sort(real(diag(Lambda)), 'ascend');
    W_TopQ = V(:, order(N-Q+1:N)) * sqrt(P/Q);   % unit-norm columns -> total power P

    % GPR posterior mean (LMMSE estimate) of h from observation z = W'*h + n.
    gpr_mmse = @(W, z) kernel_conv * W * pinv(W'*kernel_conv*W + sigma2 * eye(Q)) * z;
    % Normalized squared error of an estimate against the true channel.
    nmse = @(est, ref) norm(est - ref)^2 / norm(ref)^2;

    % Running NMSE sums over the Monte-Carlo trials for this spacing.
    acc_LS = 0;  acc_OMP = 0;  acc_AMP = 0;
    acc_MM = 0;  acc_QWF = 0;  acc_WF = 0;  acc_rand = 0;  acc_TopQ = 0;

    for iter = 1:N_iter
        if mod(iter, 100) == 0
            fprintf('d = %.4f[%d/%d] | iteration:[%d/%d] | run %.4f s\n', d, i_d, d_len, iter, N_iter, etime(clock, t0));
        end

        % Channel realization, normalized so that ||h||^2 = N.
        h = SV_channel(f, N1, N2, d);
        h = sqrt(N) * h / norm(h);

        % Circularly-symmetric complex Gaussian noise with variance sigma2.
        n_true = sqrt(sigma2) * ( randn(Q,1) + 1i*randn(Q,1) ) / sqrt(2);   % reduced (Q-pilot) noise, shared by GPR/OMP/VAMP
        n_opt  = sqrt(sigma2) * ( randn(M,1) + 1i*randn(M,1) ) / sqrt(2);   % full (M-pilot) noise for LS

        % LS with full DFT pilots.
        z_opt = W_opt' * h + n_opt;
        h_LS = pinv(W_opt') * z_opt;

        % GPR/MMSE with designed pilots; all reuse n_true for a paired comparison.
        h_WF   = gpr_mmse(W_WF,   W_WF'   * h + n_true);
        h_QWF  = gpr_mmse(W_QWF,  W_QWF'  * h + n_true);
        h_MM   = gpr_mmse(W_MM,   W_MM'   * h + n_true);
        h_TopQ = gpr_mmse(W_TopQ, W_TopQ' * h + n_true);

        % Random pilots with total power P, redrawn every trial.
        W_rand = randn(N, Q) + 1j*randn(N, Q);
        W_rand = W_rand * sqrt(P) / norm(W_rand(:));
        z_rand = W_rand' * h + n_true;
        h_rand = gpr_mmse(W_rand, z_rand);

        % Compressed sensing on the 2D-DFT dictionary using the same random-pilot observation.
        A_rand = W_rand' * DFT2;
        h_AMP = DFT2 * VAMP(z_rand, A_rand, 1.2);   % VAMP
        h_OMP = DFT2 * OMP(z_rand, A_rand, s);      % OMP with sparsity level s

        % % LASSO + TopQ
        % z_LASSO = W_TopQ' * h + n_true;
        % hbar_LASSO = LASSO(z_LASSO, W_TopQ'*DFT2, 1);
        % h_LASSO= DFT2*hbar_LASSO;

        acc_LS   = acc_LS   + nmse(h_LS,   h);
        acc_OMP  = acc_OMP  + nmse(h_OMP,  h);
        acc_AMP  = acc_AMP  + nmse(h_AMP,  h);
        % acc_LASSO = acc_LASSO + nmse(h_LASSO, h);
        acc_MM   = acc_MM   + nmse(h_MM,   h);
        acc_QWF  = acc_QWF  + nmse(h_QWF,  h);
        acc_WF   = acc_WF   + nmse(h_WF,   h);
        acc_rand = acc_rand + nmse(h_rand, h);
        acc_TopQ = acc_TopQ + nmse(h_TopQ, h);
    end

    % Average over trials; each NMSE array is written once per spacing (sliced output).
    NMSE_LS(i_d)   = acc_LS   / N_iter;
    NMSE_OMP(i_d)  = acc_OMP  / N_iter;
    NMSE_AMP(i_d)  = acc_AMP  / N_iter;
    NMSE_MM(i_d)   = acc_MM   / N_iter;
    NMSE_QWF(i_d)  = acc_QWF  / N_iter;
    NMSE_WF(i_d)   = acc_WF   / N_iter;
    NMSE_rand(i_d) = acc_rand / N_iter;
    NMSE_TopQ(i_d) = acc_TopQ / N_iter;
end


% Plot the averaged NMSE (in dB) of every estimator against the normalized
% antenna spacing d/lambda.
figure; grid on; box on; hold on;
d_over_lambda = d_list/lambda;
% RGB triplets are 8-bit values, so they are scaled by 255 (MATLAB's default palette).
p2 = plot(d_over_lambda, 10*log10(NMSE_AMP),  '-^',  'LineWidth', 1.2, 'Color', 'b');
p3 = plot(d_over_lambda, 10*log10(NMSE_OMP),  '-d',  'LineWidth', 1.2, 'Color', [237,177,32]/255);
p1 = plot(d_over_lambda, 10*log10(NMSE_LS),   '--',  'LineWidth', 1.2, 'Color', 'k');
p5 = plot(d_over_lambda, 10*log10(NMSE_rand), '-x',  'LineWidth', 1.2, 'Color', [126,47,142]/255);
p6 = plot(d_over_lambda, 10*log10(NMSE_TopQ), '-*',  'LineWidth', 1.2, 'Color', [162,20,47]/255);
p7 = plot(d_over_lambda, 10*log10(NMSE_MM),   '-s',  'LineWidth', 1.2, 'Color', [119,172,48]/255);
p8 = plot(d_over_lambda, 10*log10(NMSE_QWF),  '-o',  'LineWidth', 1.2, 'Color', [217,83,25]/255);
p9 = plot(d_over_lambda, 10*log10(NMSE_WF),   '--+', 'LineWidth', 1.2, 'Color', [0,114,189]/255);

xlabel('The ratio of antenna spacing and wavelength $d/\lambda$','Interpreter','latex','FontSize',14);
ylabel('NMSE [dB]','Interpreter','latex','FontSize',14);
% OMP is run on the random pilots W_rand in the simulation loop, so it is labelled "Rand".
legend([p1,p3,p2,p5,p6,p7,p8,p9], {'LS', 'OMP + $\mathbf{W}^{{\rm Rand}}$', 'VAMP + $\mathbf{W}^{{\rm Rand}}$', ...
     'MMSE + $\mathbf{W}^{{\rm Rand}}$', 'MMSE + $\mathbf{W}^{{\rm Top}Q}$', ...
    'MMSE + $\mathbf{W}^{{\rm MM}}$ (ours)','MMSE + $\mathbf{W}^{{\rm IF}}$ (ours)','MMSE + $\mathbf{W}^{{\rm WF}}$ (ideal)'}, ...
    'FontSize',12,'Interpreter','latex');
xlim([min(d_over_lambda), max(d_over_lambda)])
