function W = GPR(kernel, Q, N, P, sigma, mode)
% GPR  Design an N-by-Q pilot matrix W for the channel covariance "kernel".
%   W = GPR(kernel, Q, N, P, sigma, mode)
%
%   kernel : N-by-N channel covariance matrix
%   Q      : number of pilots (columns of W)
%   N      : number of antennas (rows of W)
%   P      : total power allocated to all pilots
%   sigma  : noise level. Every designer below uses sigma^2 as the noise
%            variance, so this is presumably the noise standard deviation.
%   mode   : pilot design method, one of
%              'WF'     - digital, water-filling over dominant eigenmodes
%              'QWF'    - digital, repeated dominant eigenvectors with
%                         eigenvalue shrinkage after each pick
%              'Greedy' - digital, greedy dominant eigenvectors of the
%                         updated covariance
%              'AP'     - analog, phase projection of dominant eigenvectors
%              'MM'     - analog, majorization-minimization (MMAlgorithm)
%
%   Throws 'GPR:unknownMode' for any other mode. Previously W was left
%   unassigned, and the caller got an unhelpful "output not assigned" error.
    if strcmp(mode, 'WF')
        W = GPR_Digital_WF(kernel, Q, N, P, sigma);
    elseif strcmp(mode, 'QWF')
        W = GPR_Digital_QWF(kernel, Q, N, P, sigma);
    elseif strcmp(mode, 'Greedy')
        W = GPR_Digital_Greedy(kernel, Q, N, P, sigma);
    elseif strcmp(mode, 'AP')
        W = GPR_Analog_Projection(kernel, Q, N, P, sigma);
    elseif strcmp(mode, 'MM')
        W = GPR_Analog_MM(kernel, Q, N, P, sigma);
    else
        error('GPR:unknownMode', ...
            'Unknown GPR mode. Expected ''WF'', ''QWF'', ''Greedy'', ''AP'' or ''MM''.');
    end

end


function W = GPR_Digital_WF(kernel, Q, N, P, sigma)
% GPR_DIGITAL_WF  Digital pilot design by water-filling over kernel eigenmodes.
%   Selects the K = min(Q, N) eigenvectors of kernel with the largest
%   eigenvalues. Each is scaled by sqrt(p_k), where p is the water-filling
%   power allocation with per-mode SNR gamma_k = d_k / sigma^2 and total
%   power sum(p) = P. When Q > N, the remaining Q - N pilot columns are zero,
%   so W is always N-by-Q.
%
%   The previous bisection used P as its upper bound on the water level.
%   The true level is (P + sum of active 1/gamma)/k and can exceed P, so the
%   search stalled and allocated less than P (e.g. Q = 1 gave P - 1/gamma).
%   The exact solver below has no such bound.
    [V, D] = eig(kernel);
    d = diag(D);
    [~, I] = sort(d, 'descend');

    K = min(Q, N);
    % 1e-9 presumably guards against zero eigenvalues, which would give an
    % infinite inverse SNR 1/gamma.
    gamma = (d(I(1:K)) + 1e-9) / sigma^2;
    p = GPR_WaterFill(1 ./ gamma, P);

    % zeros(N, 0) when Q <= N, so the concatenation is a no-op in that case.
    W = [V(:, I(1:K)) .* sqrt(p.'), zeros(N, Q - K)];
end

function p = GPR_WaterFill(a, P)
% GPR_WATERFILL  Exact water-filling: p_i = max(mu - a_i, 0) with sum(p) = P.
%   a : vector of floor heights (inverse SNR per mode)
%   P : total power budget
%   Returns a column vector p with numel(a) entries. All zeros if a is
%   empty or P <= 0.
    a = a(:);
    p = zeros(size(a));
    if isempty(a) || P <= 0
        return;
    end
    as = sort(a, 'ascend');
    % mu(k) is the water level when the k lowest floors are active.
    mu = (P + cumsum(as)) ./ (1:numel(as)).';
    % The optimum uses the largest k whose level still lies above floor k.
    % k = 1 always qualifies when P > 0.
    k = find(mu > as, 1, 'last');
    p = max(mu(k) - a, 0);
end

function W = GPR_Digital_Greedy(kernel, Q, N, P, sigma)
% GPR_DIGITAL_GREEDY  Greedy digital pilot design.
%   Column q is the eigenvector of the current covariance R with the largest
%   eigenvalue. R starts as kernel. After each pick w, R is updated as
%       R <- R - R*w*w'*R / (w'*R*w + sigma^2)
%   The update uses the unit-norm eigenvector. All columns are scaled
%   uniformly by sqrt(P/Q) only at the end.
    W = zeros(N, Q);
    R = kernel;

    % The first column comes from the unmodified kernel.
    [E, L] = eig(R);
    [~, k] = max(diag(L));
    W(:, 1) = E(:, k);

    for q = 2:Q
        w = W(:, q - 1);
        R = R - R*w*w'*R / (w'*R*w + sigma^2);
        [E, L] = eig(R);
        [~, k] = max(diag(L));
        W(:, q) = E(:, k);
    end

    W = sqrt(P/Q) * W;
end

function W = GPR_Digital_QWF(kernel, Q, N, P, sigma)
% GPR_DIGITAL_QWF  Digital pilot design by repeated eigenvector selection.
%   Works entirely in the eigenbasis of kernel. Each step picks the
%   eigenvector whose current eigenvalue lam is largest. It then shrinks that
%   eigenvalue to lam*sigma^2/(lam + sigma^2), which is what the update
%       R_t = R_{t-1} - R_{t-1}*w*w'*R_{t-1}/(w'*R_{t-1}*w + sigma^2)
%   gives for a unit-norm eigenvector w. The same eigenvector may be picked
%   more than once. Columns are scaled uniformly by sqrt(P/Q) at the end.
    [U, Lam] = eig(kernel);
    lam = diag(Lam);
    s2 = sigma^2;

    W = zeros(N, Q);
    for q = 1:Q
        [~, idx] = max(lam);
        W(:, q) = U(:, idx);
        lam(idx) = lam(idx) * s2 / (lam(idx) + s2);
    end

    W = sqrt(P/Q) * W;
end

function W = GPR_Analog_Projection(kernel, Q, N, P, sigma)
% GPR_ANALOG_PROJECTION  Analog (constant-modulus) pilot design by phase projection.
%   Column q comes from the eigenvector with the largest eigenvalue of the
%   current covariance R, which starts as kernel. Only its entry phases are
%   kept: every entry becomes exp(1j*phase)/sqrt(N), so each column has unit
%   norm. After each pick w, R is updated as
%       R <- R - R*w*w'*R / (w'*R*w + sigma^2)
%   All columns are scaled uniformly by sqrt(P/Q) at the end.
    toAnalog = @(v) exp(1j*angle(v))/sqrt(N);

    W = zeros(N, Q);
    R = kernel;

    % The first column comes from the unmodified kernel.
    [E, L] = eig(R);
    [~, k] = max(diag(L));
    W(:, 1) = toAnalog(E(:, k));

    for q = 2:Q
        w = W(:, q - 1);
        R = R - R*w*w'*R / (w'*R*w + sigma^2);
        [E, L] = eig(R);
        [~, k] = max(diag(L));
        W(:, q) = toAnalog(E(:, k));
    end

    W = sqrt(P/Q) * W;
end

function W = GPR_Analog_MM(kernel, Q, N, P, sigma)
% GPR_ANALOG_MM  Analog pilot design via majorization-minimization.
%   For each column, maximizing w'*R*w over the unit-modulus set is recast as
%       max w'*R*w  ->  min -w'*R*w  ->  min w'*(lambda_max*I - R)*w
%   where lambda_max is the largest eigenvalue of the current covariance R.
%   The shifted matrix is PSD for Hermitian R. That problem is handed to
%   MMAlgorithm, defined elsewhere, from a random unit-modulus start. Its
%   result is scaled by 1/sqrt(N).
%   NOTE(review): MMAlgorithm's argument meaning is inferred from this call:
%   (matrix, zero linear term, initial point, iteration cap, tolerance).
%   TODO confirm.
%   After each pick w, R is updated as
%       R <- R - R*w*w'*R / (w'*R*w + sigma^2)
%   All columns are scaled uniformly by sqrt(P/Q) at the end. Results depend
%   on the global RNG state through rand.
    maxIter = 100;
    tol = 1e-5;

    W = zeros(N, Q);
    R = kernel;

    % The first column comes from the unmodified kernel.
    [~, L] = eig(R);
    lamMax = max(diag(L));
    W(:, 1) = MMAlgorithm(lamMax*eye(N) - R, zeros(N, 1), exp(1j*2*pi*rand(N, 1)), maxIter, tol)/sqrt(N);

    for q = 2:Q
        w = W(:, q - 1);
        R = R - R*w*w'*R / (w'*R*w + sigma^2);
        [~, L] = eig(R);
        lamMax = max(diag(L));
        W(:, q) = MMAlgorithm(lamMax*eye(N) - R, zeros(N, 1), exp(1j*2*pi*rand(N, 1)), maxIter, tol)/sqrt(N);
    end

    W = sqrt(P/Q) * W;
end