% ---- Workspace reset and solver search path ----
clear all
clc
addpath('./CS solvers/')   % presumably holds VAMP / OMP / LASSO used below -- confirm

% ---- Monte-Carlo sweep ----
Q_list = 8:4:68;   % pilot lengths swept for the GPR- and CS-based estimators
N_iter = 10;       % Monte-Carlo trials per pilot length

% ---- Noise level ----
SNR_dB     = 5;
SNR_linear = 10.^(SNR_dB/10.);
sigma2     = 1/SNR_linear;   % complex noise variance added to every pilot observation


%%% System parameters
f      = 3.5e9;      % carrier frequency [Hz]
c      = 3e8;        % propagation speed [m/s]
lambda = c/f;        % carrier wavelength
d      = lambda/8;   % inter-antenna spacing
N1 = 16;             % antennas along the horizontal axis
N2 = 8;              % antennas along the vertical axis
N  = N1 * N2;        % total number of antennas
M  = N;              % pilot length used by the LS baseline

%%% Oversampled 2D-DFT dictionary (2x oversampling per axis)
D1 = 2*N1;           % horizontal dictionary size
D2 = 2*N2;           % vertical dictionary size
D  = D1 * D2;        % total number of dictionary atoms

% Centered antenna indices (column vectors)
row1 = (-(N1 - 1)/2:(N1 - 1)/2)';
row2 = (-(N2 - 1)/2:(N2 - 1)/2)';
% Uniform grids on (-1, 1] with D1 and D2 points respectively
col1 = -1 + 2/D1 : 2/D1 : 1;
col2 = -1 + 2/D2 : 2/D2 : 1;

% Per-axis steering dictionaries with unit-norm columns;
% the 2D dictionary (N x D) is their Kronecker product.
F1   = exp(1j*pi*row1*col1) / sqrt(N1);
F2   = exp(1j*pi*row2*col2) / sqrt(N2);
DFT2 = kron(F1, F2);


% Kernel matrices for the N1 x N2 array. Only kernel_conv is used below
% (as the prior covariance of the GPR/LMMSE estimators).
[kernel_J0, kernel_exp, kernel_conv] = Kernal_generate_func(f,N1,N2,d);

% One NMSE accumulator per estimator, one entry per pilot length in Q_list.
Q_len = length(Q_list);
[NMSE_LS, NMSE_OMP, NMSE_AMP, NMSE_LASSO, ...
 NMSE_rand, NMSE_TopQ, NMSE_QWF, ...
 NMSE_MM, NMSE_WF] = deal(zeros(1, Q_len));

t0 = clock;   % wall-clock start, used by the progress messages

% ---------------------------------------------------------------------
% Monte-Carlo NMSE evaluation, parallel over pilot lengths Q.
% For each Q the pilot matrices are designed once; then N_iter channel
% realizations are drawn and every estimator's NMSE is averaged into
% NMSE_<method>(i_q).
% ---------------------------------------------------------------------
assert(max(Q_list) <= N, 'Pilot length Q cannot exceed the number of antennas N.');

% ---- Loop-invariant quantities (hoisted out of the parfor) ----
% DFT pilot matrix for the LS baseline (M = N pilots) and its pseudo-inverse.
W_opt      = exp(1j*2*pi* (0:1:M-1)' * (0:1:N-1)/M)/sqrt(N);
W_opt_pinv = pinv(W_opt');

% Eigenvectors of kernel_conv ordered by ascending (real part of) eigenvalue.
% eig() only guarantees ascending order for exactly Hermitian input, so sort
% explicitly; for Hermitian input the stable sort keeps eig's order unchanged.
[V_K, Lambda_K] = eig(kernel_conv);
[~, eig_order]  = sort(real(diag(Lambda_K)), 'ascend');
V_K = V_K(:, eig_order);

% GPR/LMMSE posterior mean for pilot matrix W and observation z.
% With sigma2 > 0 (and kernel_conv assumed PSD) the regularized Gram matrix is
% positive definite, so a linear solve replaces the costlier, less accurate pinv.
gpr_estimate = @(W, z) kernel_conv * W * ((W' * kernel_conv * W + sigma2 * eye(size(W, 2))) \ z);

% Normalized squared error of an estimate against the true channel.
nmse = @(h_hat, h_ref) norm(h_hat - h_ref)^2 / norm(h_ref)^2;

% Progress is printed every 50 trials, or once per Q when N_iter < 50.
print_every = max(1, min(50, N_iter));

parfor i_q = 1:Q_len
    Q = Q_list(i_q);
    P = Q;               % total pilot energy ||W||_F^2 (unit energy per pilot)
    s = min(Q, 10);      % sparsity level passed to OMP

    % Pilot (observation) matrices, N x Q
    W_WF   = GPR(kernel_conv, Q, N, P, sqrt(sigma2), 'WF');
    W_QWF  = GPR(kernel_conv, Q, N, P, sqrt(sigma2), 'QWF');
    W_MM   = GPR(kernel_conv, Q, N, P, sqrt(sigma2), 'MM');
    W_TopQ = V_K(:, N-Q+1:N) * sqrt(P/Q);   % Q dominant eigenvectors of kernel_conv

    for iter = 1:N_iter
        if mod(iter, print_every) == 0
            fprintf('Pilot = %d[%d/%d] | iteration:[%d/%d] | run %.4f s\n', Q, i_q, Q_len, iter, N_iter, etime(clock, t0));
        end

        % Channel realization normalized to ||h||^2 = N
        h = SV_channel(f,N1,N2,d);
        h = sqrt(N) * h / norm(h);

        % Noise: n_true for the Q-pilot (reduced) observations shared by all
        % Q-pilot schemes; n_opt for the M-pilot (full) LS observation.
        n_true = sqrt(sigma2) * ( randn(Q,1) + 1i*randn(Q,1) ) / sqrt(2);
        n_opt  = sqrt(sigma2) * ( randn(M,1) + 1i*randn(M,1) ) / sqrt(2);

        % LS with the full DFT pilot matrix
        z_opt = W_opt' * h + n_opt;
        h_LS  = W_opt_pinv * z_opt;

        % GPR/LMMSE with designed pilots
        h_WF   = gpr_estimate(W_WF,  W_WF'  * h + n_true);
        h_QWF  = gpr_estimate(W_QWF, W_QWF' * h + n_true);
        h_MM   = gpr_estimate(W_MM,  W_MM'  * h + n_true);
        z_TopQ = W_TopQ' * h + n_true;
        h_TopQ = gpr_estimate(W_TopQ, z_TopQ);

        % GPR/LMMSE with random Gaussian pilots scaled to ||W_rand||_F^2 = P
        W_rand = randn(N, Q) + 1j*randn(N, Q);
        W_rand = W_rand * sqrt(P) / norm(W_rand(:));
        z_rand = W_rand' * h + n_true;
        h_rand = gpr_estimate(W_rand, z_rand);

        % VAMP with random pilots, recovering coefficients in the 2D-DFT dictionary
        h_AMP = DFT2 * VAMP(z_rand, W_rand'*DFT2, 1.2);

        % OMP and LASSO with TopQ pilots
        h_OMP   = DFT2 * OMP(z_TopQ, W_TopQ'*DFT2, s);
        h_LASSO = DFT2 * LASSO(z_TopQ, W_TopQ'*DFT2, 1);

        % Running average of the NMSE over the N_iter trials
        NMSE_LS(i_q)    = NMSE_LS(i_q)    + nmse(h_LS,    h)/N_iter;

        NMSE_OMP(i_q)   = NMSE_OMP(i_q)   + nmse(h_OMP,   h)/N_iter;
        NMSE_AMP(i_q)   = NMSE_AMP(i_q)   + nmse(h_AMP,   h)/N_iter;
        NMSE_LASSO(i_q) = NMSE_LASSO(i_q) + nmse(h_LASSO, h)/N_iter;

        NMSE_MM(i_q)    = NMSE_MM(i_q)    + nmse(h_MM,    h)/N_iter;
        NMSE_QWF(i_q)   = NMSE_QWF(i_q)   + nmse(h_QWF,   h)/N_iter;
        NMSE_WF(i_q)    = NMSE_WF(i_q)    + nmse(h_WF,    h)/N_iter;
        NMSE_rand(i_q)  = NMSE_rand(i_q)  + nmse(h_rand,  h)/N_iter;
        NMSE_TopQ(i_q)  = NMSE_TopQ(i_q)  + nmse(h_TopQ,  h)/N_iter;
    end
end


% Plot the averaged NMSE (in dB) versus pilot length Q for each estimator.
% RGB triplets are 8-bit values (MATLAB's default palette), so they are
% normalized by 255; dividing by 256 shifted every color off the palette.
figure; hold on; grid on; box on;
p2 = plot(Q_list, 10*log10(NMSE_AMP),  '->', 'Linewidth', 1.2, 'Color','b');
p3 = plot(Q_list, 10*log10(NMSE_OMP),  '-d', 'Linewidth', 1.2, 'color',[237,177,32]/255);
p1 = plot(Q_list, 10*log10(NMSE_LS),   '--', 'Linewidth', 1.2, 'color', 'k');
% NMSE_LASSO is computed but not plotted; to show it, uncomment the next line
% and add p4 (with a label) to the legend below.
% p4 = plot(Q_list, 10*log10(NMSE_LASSO),  '-d', 'Linewidth', 1.2, 'color', [0.3010 0.7450 0.9330]);
p5 = plot(Q_list, 10*log10(NMSE_rand),  '-x', 'Linewidth', 1.2, 'color', [126,47,142]/255);
p6 = plot(Q_list, 10*log10(NMSE_TopQ),  '-*', 'Linewidth', 1.2, 'color', [162,20,47]/255);
p7 = plot(Q_list, 10*log10(NMSE_MM),  '-s', 'Linewidth', 1.2, 'color', [119,172,48]/255);
p8 = plot(Q_list, 10*log10(NMSE_QWF),  '-o', 'Linewidth', 1.2, 'color', [217,83,25]/255);
p9 = plot(Q_list, 10*log10(NMSE_WF),  '--+', 'Linewidth', 1.2, 'color', [0,114,189]/255);

xlabel('Pilot length $Q$','Interpreter','latex','FontSize',14);
ylabel('NMSE [dB]','Interpreter','latex','FontSize',14);
legend([p1,p3,p2,p5,p6,p7,p8,p9], {'LS', 'OMP + $\mathbf{W}^{{\rm Top}Q}$','VAMP + $\mathbf{W}^{{\rm Rand}}$',  ...
     'MMSE + $\mathbf{W}^{{\rm Rand}}$', 'MMSE + $\mathbf{W}^{{\rm Top}Q}$', ...
    'MMSE + $\mathbf{W}^{{\rm MM}}$ (ours)','MMSE + $\mathbf{W}^{{\rm IF}}$ (ours)','MMSE + $\mathbf{W}^{{\rm WF}}$ (ideal)'}, ...
    'FontSize',12,'Interpreter','latex');
xlim([min(Q_list), max(Q_list)])
ylim([-25, 5])   % fixed y-range; adjust if the NMSE curves leave [-25, 5] dB


