clear; close all; clc;

% WMMSE precoder design for a downlink multi-user MIMO link.

% ---- System dimensions --------------------------------------------------
Nt = 4;                 % Columns of every channel matrix (transmit side)
Nr = 3;                 % Rows of every channel matrix (receive side)
K  = 3;                 % Number of channel/precoder pairs (users)
S  = min([Nt, Nr]);     % Data streams per user (columns of each precoder)

% ---- Power budget, noise level and rate weights -------------------------
Pt      = 2;            % Maximum total transmit power
sigma2  = 1;            % Noise variance
alpha   = ones(1, K);   % Per-user weights of the weighted sum-rate

% ---- Per-user containers ------------------------------------------------
% Vks: precoders, Uks: combiners, Wks: weights, Hks: channels,
% Eks: MSE matrices, Rks: interference-plus-noise covariances.
[Vks, Uks, Wks, Hks, Eks, Rks] = deal(cell(K, 1));

rng(0);                 % Fixed seed so every run draws the same realization

% ---- Random channels and initial precoders ------------------------------
% Draw order (channel real, channel imag, precoder real, precoder imag, per
% user) determines the realization obtained from the seeded generator.
for kk = 1:K
    Hre = randn([Nr, Nt]);
    Him = randn([Nr, Nt]);
    Hks{kk} = (Hre + 1i*Him)/sqrt(2);   % Unit-variance complex Gaussian entries

    Vre = randn([Nt, S]);
    Vim = randn([Nt, S]);
    Vk  = (Vre + 1i*Vim)/sqrt(2);

    % Scale so that each initial precoder carries power Pt/K.
    Vk      = Vk / norm(Vk, 'fro');
    Vks{kk} = Vk * sqrt(Pt/K);
end

Niter = 50;                     % Number of WMMSE outer iterations

wRate_arr = zeros(Niter, 1);    % Weighted sum-rate recorded at each iteration

for iter = 1:Niter

    % Sum of all transmit covariances, sum_m Vm*Vm' (shared by every user).
    CVmat = zeros(Nt);
    for m = 1:K
        Vm    = Vks{m};
        CVmat = CVmat + (Vm*Vm');
    end

    % Update Uk: MMSE receiver Uk = Ck^{-1}*Hk*Vk, where Ck is the covariance
    % of the total received signal (all users' streams plus noise) at user kk.
    for kk = 1:K
        Hk = Hks{kk}; Vk = Vks{kk};
        Ck = Hk*CVmat*Hk' + sigma2*eye(Nr);

        Uk = Ck\(Hk*Vk);
        Uks{kk} = Uk;

        % MSE matrix for the current (Uk, Vk). It is stored for inspection
        % only; none of the updates below read it.
        Eks{kk} = eye(S) - Vk'*Hk'*Uk - Uk'*Hk*Vk + Uk'*(Ck)*Uk;
    end

    % Update Wk = I + Vk'*Hk'*Rk^{-1}*Hk*Vk, with Rk the interference-plus-
    % noise covariance seen by user kk (all users except kk).
    for kk = 1:K
        Hk = Hks{kk};
        Rk = sigma2*eye(Nr);
        for ii = 1:K
            if ii ~= kk
                Rk = Rk + Hk*Vks{ii}*Vks{ii}'*Hk';
            end
        end
        Rks{kk} = Rk;
        HVk     = Hk*Vks{kk};
        % Backslash solve instead of inv(Rk): same quantity, better
        % conditioned and cheaper than forming the explicit inverse.
        Wk      = eye(S) + HVk'*(Rk\HVk);
        Wks{kk} = Wk;
    end

    % Weighted sum-rate of the current precoders. The logarithm is taken in
    % base 2 so that the value really is in bits, as the printed unit states.
    wRate = 0;
    for kk = 1:K
        Hk = Hks{kk};
        Vk = Vks{kk};
        rate_k = real(log(det(eye(Nr)+Rks{kk}\(Hk*(Vk*Vk')*Hk'))))/log(2);
        wRate  = wRate + alpha(kk)*rate_k;
    end
    fprintf('Rate = %e bps/Hz\n', wRate);
    wRate_arr(iter) = wRate;

    % Update Vk with power allocation mechanism
    Bmat = zeros(Nt);
    Amat = zeros(Nt);

    for kk = 1:K
        Hk = Hks{kk};
        Wk = Wks{kk};
        Uk = Uks{kk};
        Bmat = Bmat + alpha(kk)^2*Hk'*Uk*(Wk*Wk')*Uk'*Hk;
        Amat = Amat + alpha(kk)*Hk'*Uk*Wk*Uk'*Hk;
    end

    % Solve equation Tr(Bmat(lambda*I+Amat)^{-2}) = P_T.
    % powerGap(lam) is the total transmit power obtained with multiplier lam
    % minus the budget Pt. With the Hermitian positive semidefinite Amat/Bmat
    % built above it is non-increasing in lam, so bracketing plus bisection
    % finds its root.
    powerGap = @(lam) real(trace(Bmat/((lam*eye(Nt)+Amat)^2))) - Pt;

    % Coarse log-spaced scan to locate the neighbourhood of the root.
    Narr = 200;
    lambda_arr = logspace(-1.5, 1.5, Narr);
    val_arr    = arrayfun(powerGap, lambda_arr);
    [~, midx]  = min(abs(val_arr));

    % Neighbouring grid points form the initial bracket. The indices are
    % clamped because the best grid point can be the first or last one
    % (root outside the scanned range), where midx-1 / midx+1 do not exist.
    lambda_l = lambda_arr(max(midx-1, 1));
    lambda_r = lambda_arr(min(midx+1, Narr));

    % Widen the bracket when the root is outside the scanned range.
    lambdaFloor = 1e-12;    % Below this the multiplier is treated as zero
    lambdaCeil  = 1e12;     % Safety cap so the widening always terminates
    while powerGap(lambda_r) > 0 && lambda_r < lambdaCeil
        lambda_r = 10*lambda_r;
    end
    while powerGap(lambda_l) < 0 && lambda_l > lambdaFloor
        lambda_l = lambda_l/10;
    end

    % Bisection search
    if powerGap(lambda_l) <= 0
        % Budget already met at the smallest multiplier considered
        % (power constraint inactive, or lambda_l is itself the root).
        lambda_m = lambda_l;
    else
        NbsIter = 30;       % Enough halvings even after a widened bracket
        for bsiter = 1:NbsIter
            lambda_m = (lambda_l + lambda_r)/2;
            if powerGap(lambda_m) > 0
                lambda_l = lambda_m;    % Too much power: increase lambda
            else
                lambda_r = lambda_m;    % Within budget: decrease lambda
            end
        end
    end
    fprintf('lambda = %e\n', lambda_m);
    lambda = lambda_m;

    % Evaluate V_k = alpha_k * (lambda*I + Amat)^{-1} * Hk' * Uk * Wk.
    A = lambda*eye(Nt)+Amat;
    for kk = 1:K
        Hk = Hks{kk};
        Uk = Uks{kk};
        Wk = Wks{kk};
        Vk = A\(alpha(kk)*Hk'*Uk*Wk);
        Vks{kk} = Vk;
    end

    % Report the total power actually used by the updated precoders
    % (should be close to Pt when the power constraint is active).
    Pt_tmp = 0;
    for kk = 1:K
        Pt_tmp = Pt_tmp + norm(Vks{kk}, 'fro')^2;
    end
    fprintf('Tmp Pt = %e\n=======================================\n', Pt_tmp);

end

%% Visualization
% Plot the weighted sum-rate recorded at each iteration against its index.
iterAxis = (1:Niter).';
plot(iterAxis, wRate_arr);
grid on;

xlabel('Iterations');
ylabel('Sum Rate');


