function opt_result = cWMMSE(Hmks, alpha, Pmax, sigma2_e, NiterMax, eps, vbs)
% cWMMSE  Alternating weighted-MMSE design of linear precoders and combiners.
%
%   opt_result = cWMMSE(Hmks, alpha, Pmax, sigma2_e, NiterMax, eps, vbs)
%
%   Each iteration updates, in this order: the receive combiners U_mk, the
%   weight matrices W_mk, the weighted sum-rate (for logging/early stopping),
%   and finally the transmit precoders V_mk under a per-(m,k) power budget.
%
%   Inputs
%     Hmks      M-by-K cell array of channel matrices. The sizes Nr-by-Nt are
%               read from Hmks{1,1} and every entry is used with those sizes.
%     alpha     Rate weights, indexed as alpha(m,k).
%     Pmax      Power budgets, indexed as Pmax(m,k); compared against
%               norm(V_mk,'fro')^2.
%     sigma2_e  Scalar noise variance; the noise covariance is sigma2_e*eye(Nr).
%     NiterMax  Maximum number of outer iterations.
%     eps       Relative tolerance, used both for the power-constraint
%               bisection and for early stopping on the weighted sum-rate.
%               NOTE: this parameter shadows MATLAB's built-in eps inside
%               this function.
%     vbs       Verbosity (optional, default 2). vbs >= 1 prints an
%               end-of-iteration banner and bisection warnings; vbs >= 2 also
%               prints the weighted sum-rate and the per-link rate table.
%
%   Output (struct)
%     Palloc_trace  iter-by-M-by-K array of norm(V_mk,'fro')^2 per iteration.
%     wRate_trace   Weighted sum-rate per executed iteration.
%     Vmks, Umks    Final precoders / combiners (M-by-K cells).
%     Emks          MSE matrices returned by getcWMMSEConvexTargetFunction
%                   for the final Umks, Vmks, Wmks.
%     Rmks, Rates   Interference-plus-noise covariances and per-link rates.
%                   These are evaluated inside the last iteration BEFORE its
%                   precoder update, i.e. for the precoders that entered that
%                   iteration, not for the returned Vmks.
%
%   The precoders are initialised with randn, so results depend on the state
%   of the global random number generator.

if ~exist("vbs", "var") || isempty(vbs)
    vbs = 2; 
end

% Optimization decision variables. 
[M, K]      = size(Hmks); 
[Nr, Nt]    = size(Hmks{1,1}); 
S           = min([Nt, Nr]);    % Max number of available data streams

Vmks = cell(M, K);   % Tx Precoder
Umks = cell(M, K);   % Rx Combiner
Wmks = cell(M, K);   % Weights 
Rmks = cell(M, K);   % Classical Rx noise covariance for (m,k)-th UE. 
Emks = cell(M, K);   % MSE matrices; stays empty if no iteration is executed
HVs  = cell(M, K);   % Cached effective channels H_mk*V_mk of the current iteration

Rates = zeros(M,K); 

% Initialize precoder V_mk: random complex Gaussian, scaled to use the full budget.
for mm = 1:M
    for kk = 1:K
        Vmk = (randn([Nt, S])+1i*randn([Nt, S]))/sqrt(2); 
        
        Vmk = Vmk / norm(Vmk, 'fro') * sqrt(Pmax(mm, kk));  
        Vmks{mm, kk} = Vmk; 
    end
end

wRate_trace     = zeros(NiterMax, 1); 
Palloc_trace    = zeros(NiterMax, M, K); 

% Safety cap for the inner bisection. A double-precision bracket collapses
% long before this many halvings; the cap is only a backstop.
NbisectMax  = 200; 
INt         = eye(Nt); 

for iter = 1:NiterMax
    
    % Total received covariance R = sigma2_e*I + sum_{m,k} (H V)(H V)'.
    R = sigma2_e*eye(Nr); 
    for mm = 1:M
        for kk = 1:K
            HV = Hmks{mm,kk}*Vmks{mm,kk}; 
            HVs{mm,kk} = HV; 
            R = R + (HV*HV'); 
        end
    end

    % Combiner design: U_mk = R^{-1} H_mk V_mk, solved without forming inv(R).
    for mm = 1:M
        for kk = 1:K
            Umks{mm,kk} = R \ HVs{mm,kk}; 
        end
    end
    
    % Weight design and weighted sum-rate (same visiting order for both).
    wRate = 0; 
    for mm = 1:M
        for kk = 1:K
            HV  = HVs{mm,kk}; 
            Rmk = R - (HV*HV');     % everything in R except link (m,k) itself
            Wmks{mm,kk} = eye(S) + (HV)'*(Rmk\HV); 
            Rmks{mm,kk} = Rmk; 

            rate            = real(log2(det(eye(Nr)+Rmk \ (HV*HV') ))); 
            Rates(mm,kk)    = rate; 
            wRate           = wRate + alpha(mm,kk)*rate; 
        end
    end

    if vbs >= 2
        fprintf('Rate = %e bps/Hz\n', wRate); 
        disp(Rates); 
    end
    wRate_trace(iter) = wRate; 

    % Precoder design
    A = zeros(Nr); 
    for mm = 1:M
        for kk = 1:K
            Umk = Umks{mm,kk}; 
            Wmk = Wmks{mm,kk}; 
            A = A + alpha(mm,kk)*Umk*Wmk*Umk'; 
        end
    end
    for mm = 1:M
        for kk = 1:K
            Hmk = Hmks{mm,kk}; Umk = Umks{mm,kk}; Wmk = Wmks{mm,kk}; 
            Tmp = Hmk'*A*Hmk; 
            rhs = alpha(mm,kk)*(Hmk'*Umk*Wmk);   % right-hand side shared by every lambda

            % Coarse log-spaced scan of the multiplier lambda to bracket the
            % point where the precoder power crosses Pmax(mm,kk).
            Nl          = 45; 
            lambda_arr  = logspace(-5, 2, Nl)*real(trace(Tmp))/S;
            P_arr       = zeros(1, Nl); 

            for idx = 1:Nl
                Vmk = (lambda_arr(idx)*INt + Tmp)\rhs; 
                P_arr(idx) = norm(Vmk, 'fro')^2; 
            end

            % The largest scanned lambda must already satisfy the budget,
            % otherwise no bracket exists on this grid.
            assert(P_arr(end) < Pmax(mm,kk)); 
            if P_arr(1) < Pmax(mm,kk)
                % Very small lambda: budget is met even at the smallest grid value.
                Vmk = (lambda_arr(1)*INt + Tmp)\rhs; 

            else
                lam_idx  = find(P_arr < Pmax(mm,kk), 1, 'first'); 
                lambda_r = lambda_arr(lam_idx);     % power below budget here
                lambda_l = lambda_arr(lam_idx-1);   % power at/above budget here

                % Bisection on lambda, bounded so it cannot spin forever when
                % the requested tolerance is below floating-point resolution.
                converged = false; 
                for bisIter = 1:NbisectMax
                    lambda_m = (lambda_l+lambda_r)/2; 
                    Vmk_m   = (lambda_m*INt + Tmp)\rhs; 
                    P_m     = norm(Vmk_m, 'fro')^2; 
                    
                    if abs(P_m - Pmax(mm,kk)) >= eps*Pmax(mm,kk)
                        % Midpoint equal to an endpoint means the bracket
                        % cannot shrink any further.
                        collapsed = (lambda_m <= lambda_l) || (lambda_m >= lambda_r); 
                        if P_m - Pmax(mm,kk) > 0
                            lambda_l = lambda_m; 
                        else
                            lambda_r = lambda_m; 
                        end
                        if collapsed
                            break;
                        end
                    else
                        Vmk = Vmk_m; 
                        converged = true; 
                        break;
                    end
                end

                if ~converged
                    % Fall back to the right end of the bracket: its power
                    % never exceeds the budget, so the result stays feasible.
                    Vmk = (lambda_r*INt + Tmp)\rhs; 
                    if vbs >= 1
                        warning('cWMMSE:bisectionNotConverged', ...
                            'Power bisection for link (%d,%d) stopped before reaching the requested tolerance.', ...
                            mm, kk); 
                    end
                end
            end
            
            Vmks{mm,kk}         = Vmk; 
            Palloc_trace(iter, mm, kk) = norm(Vmk, 'fro')^2; 
        end
    end
    
    % Early-stopping on the relative change of the weighted sum-rate.
    % Written without a division so that wRate == 0 does not produce NaN.
    if iter >= 2 && abs(wRate-wRate_trace(iter-1)) <= eps*abs(wRate)
        break;
    end 

    if vbs >= 1
        fprintf('===================== ITER %d END ===================== \n\n', iter); 
    end
end

% The MSE matrices are only reported for the final variables, so they are
% evaluated once here instead of in every iteration.
if NiterMax >= 1
    [~, Emks] = getcWMMSEConvexTargetFunction(sigma2_e, alpha, Umks, Vmks, Wmks, Hmks); 
end

opt_result.Palloc_trace = Palloc_trace(1:iter, :, :); 
opt_result.wRate_trace  = wRate_trace(1:iter); 

opt_result.Vmks         = Vmks;     % Linear precoder
opt_result.Umks         = Umks;     % Linear combiner 
opt_result.Emks         = Emks;     % MMSE-detection MSE matrix
opt_result.Rmks         = Rmks;     % noise+interference covariance 
opt_result.Rates        = Rates; 
end


function [loss, Emks] = getcWMMSEConvexTargetFunction(sigma2_e, alpha, Umks, Vmks, Wmks, Hmks)
% Evaluates the weighted objective
%     loss = sum_{m,k} alpha(m,k) * real( trace(W_mk*E_mk) - log(det(W_mk)) )
% and returns the matrices E_mk it is built from.
%
%   sigma2_e  scalar multiplying eye(Nr) in the received covariance
%   alpha     weights, indexed as alpha(m,k)
%   Umks, Vmks, Wmks, Hmks   M-by-K cell arrays (combiners, precoders,
%             weights, channels); M and K are taken from size(Wmks)
%
%   loss      scalar objective value
%   Emks      M-by-K cell with
%             E_mk = I - U_mk'*H_mk*V_mk - (H_mk*V_mk)'*U_mk + U_mk'*R*U_mk

[M, K]  = size(Wmks); 
S       = size(Vmks{1,1}, 2);   % number of precoder columns
Nr      = size(Hmks{1,1}, 1);   % number of channel rows

% Pass 1: accumulate R = sigma2_e*I + sum (H V)(H V)', caching each H*V.
G = cell(M, K); 
R = sigma2_e*eye(Nr); 
for m = 1:M
    for k = 1:K
        Gmk     = Hmks{m,k}*Vmks{m,k}; 
        G{m,k}  = Gmk; 
        R       = R + (Gmk*Gmk'); 
    end
end

% Pass 2: build every E_mk and add its weighted contribution to the loss.
Emks = cell(M, K); 
loss = 0; 
for m = 1:M
    for k = 1:K
        U   = Umks{m,k}; 
        W   = Wmks{m,k}; 
        Gmk = G{m,k}; 

        E = eye(S) - (U'*Gmk) - (Gmk'*U) + U'*R*U; 
        Emks{m,k} = E; 

        term = trace(W*E) - log(det(W));
        loss = loss + alpha(m,k)*real(term); 
    end
end

end



