function opt_result = qWMMSE_IFdiv(RAQR_config, mu_arr, gamma_arr, delta_arr, A_LO_arr, Hmks, Cqms, C_sig, alpha, Pmax, sigma2_e, NiterMax, eps, lr_ALO, method, vbs)
% qWMMSE_IFdiv  WMMSE precoding for each of the M sub-bands, alternated with
% a gradient update of the Rydberg-receiver LO amplitudes A_LO_arr.
%
% Every outer iteration
%   1. evaluates the per-band quantum transconductances gqs (and the matrix
%      Jq used for the chain rule) for the current A_LO_arr;
%   2. designs the combiners U_mk and the weights W_mk;
%   3. records the alpha-weighted sum rate of the current precoders;
%   4. redesigns each precoder V_mk under its own power budget Pmax(m,k)
%      (Lagrange multiplier searched on a log grid, refined by bisection);
%   5. if lr_ALO > 0, moves A_LO_arr along the negative gradient of the
%      WMMSE objective, either with a fixed step ("GradientDescent") or by
%      shrinking lr_ALO by 0.7 until the objective does not increase
%      ("Armijo-Goldstein"). In the latter mode the reduced step carries
%      over to later iterations, and amplitudes are floored at 1e-4.
%
% Inputs
%   RAQR_config  receiver configuration struct, forwarded to the quantum model
%   mu_arr, gamma_arr, delta_arr
%                quantum-model parameters, forwarded to getMultiband_gq
%   A_LO_arr     initial LO amplitudes, one per band (presumably an M-by-1
%                column: it is transposed with .' when stored in ALO_trace)
%   Hmks         M-by-K cell of Nr-by-Nt channel matrices
%   Cqms         M-element cell of Nr-by-Nr quantum-noise covariances
%   C_sig        per-band factor multiplying gqs(m)
%   alpha        M-by-K rate weights
%   Pmax         M-by-K per-(m,k) transmit power budgets
%   sigma2_e     variance of the additive sigma2_e*I noise term
%   NiterMax     number of outer iterations
%   eps          relative tolerance of the power bisection
%   lr_ALO       A_LO step size (<= 0 keeps A_LO_arr fixed)
%   method       "Armijo-Goldstein" or "GradientDescent"
%   vbs          verbosity, default 2 (>=1: iteration banner, >=2: details)
%
% Output struct opt_result
%   wRate_trace   NiterMax-by-1 weighted sum rate per iteration
%   gq_trace, ALO_trace, gqder_trace
%                 NiterMax-by-M traces of gqs, A_LO_arr and dF/dA_LO
%                 (filled only when lr_ALO > 0)
%   Palloc_trace  NiterMax-by-M-by-K allocated power ||V_mk||_F^2
%   Vmks, Umks, Emks, Rmks
%                 final precoders, combiners, MSE matrices and
%                 noise-plus-interference covariances

if ~exist("vbs", "var") || isempty(vbs)
    vbs = 2; 
end

% Problem dimensions
[M, K]      = size(Hmks); 
[Nr, Nt]    = size(Hmks{1,1}); 
S           = min([Nt, Nr]);  % Max number of available data streams

% Safety cap on the precoder power bisection. Without it the bisection
% never stops when eps is below what floating-point lambda can resolve.
maxBisectIter = 200; 

Vmks  = cell(M, K);   % Tx precoders
Umks  = cell(M, K);   % Rx combiners
Wmks  = cell(M, K);   % Weights
Rmks  = cell(M, K);   % Rydberg-Rx noise-plus-interference covariance of UE (m,k)
Emks  = cell(M, K);   % MSE matrices (stay empty if NiterMax == 0)
Rms   = cell(M, 1);   % Total received covariance per band
Ams   = cell(M, 1);   % sum_k alpha_mk U_mk W_mk U_mk' per band
Rates = zeros(M, K); 

% Random complex Gaussian initial precoders, scaled so ||V_mk||_F^2 = Pmax(m,k)
for mm = 1:M
    for kk = 1:K
        Vmk = (randn([Nt, S])+1i*randn([Nt, S]))/sqrt(2); 
        Vmks{mm, kk} = Vmk / norm(Vmk, 'fro') * sqrt(Pmax(mm, kk)); 
    end
end

% Per-band traces: one column per sub-band (sized by M, not a fixed count).
wRate_trace     = zeros(NiterMax, 1); 
gq_trace        = zeros(NiterMax, M); 
ALO_trace       = zeros(NiterMax, M); 
gqder_trace     = zeros(NiterMax, M); 
Palloc_trace    = zeros(NiterMax, M, K); 

for iter = 1:NiterMax
    
    % Quantum analysis for the current LO amplitudes
    [gqs, Jq] = getMultiband_gq(RAQR_config, M, mu_arr, gamma_arr, delta_arr, A_LO_arr); 

    if vbs >= 2
        fprintf('Quantum transconductance gq (mS): \n'); 
        disp(abs(gqs)*1e3); 
        fprintf('LO amplitudes: \n'); 
        disp(A_LO_arr); 
    end
    
    % Noise covariance shared by all bands: sigma2_e*I + sum_m (gq_m C_m)^2 Cq_m
    Cwq = sigma2_e*eye(Nr); 
    for mm = 1:M
        Cwq = Cwq + (gqs(mm)*C_sig(mm))^2*Cqms{mm}; 
    end
    
    % Total received covariance of each band
    for mm = 1:M
        Rm = Cwq; 
        for kk = 1:K
            HV = Hmks{mm,kk}*Vmks{mm,kk}; 
            Rm = Rm + (gqs(mm)*C_sig(mm))^2*(HV*HV'); 
        end
        Rms{mm} = Rm; 
    end
    
    % Combiner U_mk = gq_m C_m R_m^{-1} H_mk V_mk (linear solve instead of inv)
    for mm = 1:M
        for kk = 1:K
            HV = Hmks{mm,kk}*Vmks{mm,kk}; 
            Umks{mm,kk} = (gqs(mm)*C_sig(mm))*(Rms{mm}\HV); 
        end
    end
    
    % Weights W_mk = I + (gq_m C_m)^2 (H V)' R_mk^{-1} (H V), where R_mk
    % removes UE (m,k)'s own signal from R_m
    for mm = 1:M
        for kk = 1:K
            HV  = Hmks{mm,kk}*Vmks{mm,kk}; 
            Rmk = Rms{mm}-(gqs(mm)*C_sig(mm))^2*(HV*HV'); 
            Wmks{mm,kk} = eye(S) + (gqs(mm)*C_sig(mm))^2*(HV)'*(Rmk\HV); 
            Rmks{mm,kk} = Rmk; 
        end
    end
    
    % Weighted sum rate of the current precoders
    wRate = 0; 
    for mm = 1:M
        for kk = 1:K
            HV   = Hmks{mm,kk}*Vmks{mm,kk}; 
            rate = real(log2(det(eye(Nr)+Rmks{mm,kk} \ ((gqs(mm)*C_sig(mm))^2*(HV*HV')) ))); 
            Rates(mm,kk) = rate; 
            wRate        = wRate + alpha(mm,kk)*rate; 
        end
    end
    if vbs >= 2
        fprintf('Rate = %e bps/Hz\n', wRate); 
        disp(Rates); 
    end
    wRate_trace(iter) = wRate; 

    % Precoder design
    for mm = 1:M
        Am = zeros(Nr); 
        for kk = 1:K
            Umk = Umks{mm,kk}; 
            Am  = Am + alpha(mm,kk)*Umk*Wmks{mm,kk}*Umk'; 
        end
        Ams{mm} = Am; 
    end
    for mm = 1:M
        for kk = 1:K
            Hmk = Hmks{mm,kk}; 
            Tmp = (gqs(mm)*C_sig(mm))^2*(Hmk'*Ams{mm}*Hmk); 
            % Right-hand side does not depend on the multiplier: compute once.
            rhs = alpha(mm,kk)*gqs(mm)*C_sig(mm)*(Hmk'*Umks{mm,kk}*Wmks{mm,kk}); 
            Vmk = powerConstrainedPrecoder(Tmp, rhs, Pmax(mm,kk), eps, S, maxBisectIter); 
            Vmks{mm,kk} = Vmk; 
            Palloc_trace(iter, mm, kk) = norm(Vmk, 'fro')^2; 
        end
    end
    
    % Optimize quantum A_LO parameters 
    [f_old, Emks]   = getWMMSEConvexTargetFunction(gqs, Cqms, C_sig, sigma2_e, alpha, Umks, Vmks, Wmks, Hmks); 

    if lr_ALO > 0
        dfq_da          = getQuantumTargetFunctionDerivatives(gqs, Jq, C_sig, alpha, Hmks, Umks, Vmks, Wmks, Cqms); 
    
        if method == "Armijo-Goldstein"
            % Backtrack until the objective does not increase.
            while true
                A_new           = A_LO_arr - lr_ALO*dfq_da.'; 
                A_new(A_new <= 1e-4) = 1e-4;   % keep LO amplitudes positive
                [gqs_new, ~]    = getMultiband_gq(RAQR_config, M, mu_arr, gamma_arr, delta_arr, A_new); 
                [f_new, ~]      = getWMMSEConvexTargetFunction(gqs_new, Cqms, C_sig, sigma2_e, alpha, Umks, Vmks, Wmks, Hmks); 
        
                if f_new > f_old
                    lr_ALO = lr_ALO*0.7;
                else
                    A_LO_arr = A_new; 
                    break; 
                end 
            end
    
        elseif method == "GradientDescent" 
            A_LO_arr = A_LO_arr - lr_ALO*dfq_da.'; 
    
        else
            error('Method %s not defined.', method);
        end
    
        gqder_trace(iter, :)    = dfq_da;       % Derivative of the target function w.r.t. A_LO 
        gq_trace(iter, :)       = gqs.'; 
        ALO_trace(iter, :)      = A_LO_arr.'; 
    end

    if vbs >= 1
        fprintf('===================== ITER %d END ===================== \n\n', iter); 
    end
end

opt_result.ALO_trace    = ALO_trace; 
opt_result.gq_trace     = gq_trace; 
opt_result.gqder_trace  = gqder_trace; 
opt_result.Palloc_trace = Palloc_trace; 
opt_result.wRate_trace  = wRate_trace; 

opt_result.Vmks         = Vmks;     % Linear precoder
opt_result.Umks         = Umks;     % Linear combiner 
opt_result.Emks         = Emks;     % MMSE-detection MSE matrix
opt_result.Rmks         = Rmks;     % noise+interference covariance 

end


function Vmk = powerConstrainedPrecoder(Tmp, rhs, Pmax_mk, tol, S, maxIter)
% Return V(lambda) = (lambda*I + Tmp) \ rhs, choosing lambda so that
% ||V||_F^2 <= Pmax_mk.
%
% lambda is first searched on a 45-point log grid scaled by trace(Tmp)/S.
% The largest grid value is doubled until it is feasible. If the smallest
% grid value is already feasible, its (nearly unconstrained) solution is
% returned. Otherwise lambda is bisected between the last infeasible and
% the first feasible grid points until |P - Pmax_mk| < tol*Pmax_mk. If that
% does not happen within maxIter halvings, the solution at the current
% feasible end-point is returned.
Nt      = size(Tmp, 1); 
solveV  = @(lam) (lam*eye(Nt) + Tmp) \ rhs; 
powerOf = @(V) norm(V, 'fro')^2; 

Nl         = 45; 
lambda_arr = logspace(-5, 4, Nl)*real(trace(Tmp))/S;
P_arr      = zeros(1, Nl); 
for idx = 1:Nl
    P_arr(idx) = powerOf(solveV(lambda_arr(idx))); 
end

% Enlarge the biggest multiplier until it yields a feasible precoder.
while P_arr(end) >= Pmax_mk
    lambda_arr(end) = lambda_arr(end)*2; 
    P_arr(end)      = powerOf(solveV(lambda_arr(end))); 
end

if P_arr(1) < Pmax_mk
    % Very small lambda already satisfies the budget.
    Vmk = solveV(lambda_arr(1)); 
    return; 
end

% lam_idx >= 2 here because P_arr(1) is infeasible.
lam_idx  = find(P_arr < Pmax_mk, 1, 'first'); 
lambda_l = lambda_arr(lam_idx-1);   % infeasible end
lambda_r = lambda_arr(lam_idx);     % feasible end
for it = 1:maxIter
    lambda_m = (lambda_l + lambda_r)/2; 
    Vmk_m    = solveV(lambda_m); 
    P_m      = powerOf(Vmk_m); 
    if abs(P_m - Pmax_mk) < tol*Pmax_mk
        Vmk = Vmk_m; 
        return; 
    end
    if P_m - Pmax_mk > 0
        lambda_l = lambda_m; 
    else
        lambda_r = lambda_m; 
    end
end
% Tolerance not reached: fall back to the feasible end of the bracket.
Vmk = solveV(lambda_r); 
end


function [gqs, Jq] = getMultiband_gq(RAQR_config, M, mu_arr, gamma_arr, delta_arr, A_LO_arr)
    % getMultiband_gq  Per-band quantum transconductances of the Rydberg
    % receiver, plus the matrix Jq that the caller uses to chain
    % derivatives from gqs to the LO amplitudes.
    %
    % All model inputs are forwarded to getDCTransferCoeffs, which supplies
    % Tq, dTq and the density matrix rho_bar.
    % NOTE(review): Jq is presumably d gqs / d A_LO_arr (it is used that way
    % by getQuantumTargetFunctionDerivatives); this assumes dTq is the
    % derivative of Tq w.r.t. A_LO_arr - confirm in getDCTransferCoeffs.
    % The gq definition differs from other formulations: it is intended to
    % be positive, but nothing in this function enforces that.

    % Physical constants (SI)
    planck     = 6.626e-34; 
    hbar       = planck / (2*pi); 
    bohrRadius = 5.2918e-11; 
    qe         = 1.6e-19; 
    eps0       = 8.85e-12;
    dipole12   = 4.5022*qe*bohrRadius;      % probe-transition dipole moment
    lightSpeed = getPhysicalConstant('LightSpeed');
    kp         = 2*pi/(RAQR_config.lambda_p);   % probe wavenumber

    [Tq, dTq, rho_bar] = getDCTransferCoeffs(RAQR_config, M, mu_arr, gamma_arr, delta_arr, A_LO_arr); 

    % Probe absorption and the resulting DC photocurrent
    probePower      = 29.8e-6; 
    absorptionCoeff = -(kp*RAQR_config.N0*(dipole12^2))/(eps0*hbar*RAQR_config.Omega_p)*imag(rho_bar(2,1)); 
    transmission    = exp(-2*absorptionCoeff*RAQR_config.d); 
    photocurrent0   = (RAQR_config.eta*probePower)/(hbar*(lightSpeed*kp))*qe*transmission; 
    couplingScale   = (2*kp*RAQR_config.N0*(dipole12^2))/(eps0*hbar*RAQR_config.Omega_p); 
    preFactor       = photocurrent0*couplingScale; 

    gqs = preFactor*mu_arr/(2*hbar).*imag(Tq);

    % Jq = term through dTq + term through the probe transmission
    rabiScale   = mu_arr/(2*hbar); 
    directTerm  = preFactor*rabiScale.*imag(dTq).*(rabiScale.'); 
    absorbTerm  = preFactor*(couplingScale*RAQR_config.d)*rabiScale.*(imag(Tq)*imag(Tq).').*(rabiScale.'); 
    Jq          = directTerm + absorbTerm; 

end

function [loss, Emks] = getWMMSEConvexTargetFunction(gqs, Cqms, C_sig, sigma2_e, alpha, Umks, Vmks, Wmks, Hmks)
% Evaluate the WMMSE surrogate objective for fixed U, V and W:
%     loss = sum_{m,k} alpha(m,k) * real( Tr(W_mk E_mk) - log det W_mk )
% where
%     E_mk = I - g_m U_mk' H_mk V_mk - g_m (H_mk V_mk)' U_mk + U_mk' R_m U_mk,
%     g_m  = gqs(m) * C_sig(m),
%     R_m  = sigma2_e I + sum_l g_l^2 Cq_l + g_m^2 sum_k H_mk V_mk (H_mk V_mk)'.
% Also returns the MSE matrices E_mk in an M-by-K cell.
[nBands, nUsers] = size(Wmks); 
nStreams = size(Vmks{1,1}, 2);
nRx      = size(Hmks{1,1}, 1); 

% Noise covariance common to every band
noiseCov = sigma2_e*eye(nRx); 
for b = 1:nBands
    noiseCov = noiseCov + (gqs(b)*C_sig(b))^2*Cqms{b}; 
end

Emks = cell(nBands, nUsers); 
loss = 0; 
for b = 1:nBands
    gain = gqs(b)*C_sig(b); 

    % Total received covariance of band b: noise plus every user's signal
    totalCov = noiseCov; 
    for u = 1:nUsers
        HV       = Hmks{b,u}*Vmks{b,u}; 
        totalCov = totalCov + gain^2*(HV*HV'); 
    end

    % MSE matrix of each user, accumulated into the weighted objective
    for u = 1:nUsers
        U  = Umks{b,u}; 
        HV = Hmks{b,u}*Vmks{b,u}; 
        E  = eye(nStreams) - gain*(U'*HV) - gain*(HV'*U) + U'*totalCov*U; 
        Emks{b,u} = E; 

        W    = Wmks{b,u}; 
        loss = loss + alpha(b,u)*real(trace(W*E) - log(det(W))); 
    end
end

end


function dfq_da = getQuantumTargetFunctionDerivatives(gqs, Jq, C_sig, alpha, Hmks, Umks, Vmks, Wmks, Cqms)
    % Gradient of the WMMSE objective
    %     f = sum_{m,k} alpha(m,k) * ( Tr(W_mk E_mk) - log det W_mk )
    % w.r.t. the LO amplitudes, with U, V and W held fixed, via the chain
    % rule dfq_da = (df/dgq) * Jq. The result is a 1-by-M row vector.
    %
    % E_mk depends on gq_ell through
    %   * the direct terms -gq_m C_m (U'HV + (U'HV)'), only when m == ell;
    %   * U' R_m U, with dR_m/dgq_ell = 2 gq_ell C_ell^2 (Cq_ell + [m==ell] sum_k HV HV').
    % Both contributions are weighted by alpha(m,k), matching the objective.
    [M, K]  = size(Hmks); 
    [Nr, ~] = size(Hmks{1,1}); 
    dfq_dgq = zeros(1, M);

    % Signal covariance sum_k (H_mk V_mk)(H_mk V_mk)' of each band
    Rtmp = cell(M, 1); 
    for mm = 1:M
        Rt = zeros(Nr); 
        for kk = 1:K
            HV = Hmks{mm,kk}*Vmks{mm,kk}; 
            Rt = Rt + HV*HV'; 
        end
        Rtmp{mm} = Rt; 
    end

    % Rprime{m,ell} = dR_m / dgq_ell
    Rprime = cell(M, M); 
    for ell = 1:M
        for mm = 1:M
            Rprime_tmp = Cqms{ell}; 
            if mm == ell
                Rprime_tmp = Rprime_tmp + Rtmp{ell}; 
            end
            Rprime{mm,ell} = 2*gqs(ell)*C_sig(ell)^2*Rprime_tmp; 
        end
    end

    for ell = 1:M
        der_tmp = 0;
        for mm = 1:M
            for kk = 1:K
                Wmk = Wmks{mm,kk}; Umk = Umks{mm,kk}; 
                % dE_mk / dgq_ell: covariance term ...
                dE = Umk'*Rprime{mm,ell}*Umk; 
                % ... plus the direct signal term for the band's own gq
                if mm == ell
                    UpHV = Umk'*Hmks{mm,kk}*Vmks{mm,kk}; 
                    dE   = dE - C_sig(mm)*(UpHV + UpHV'); 
                end
                der_tmp = der_tmp + alpha(mm,kk)*trace(Wmk*dE); 
            end
        end
        dfq_dgq(ell) = real(der_tmp); 
    end
    dfq_da = dfq_dgq*Jq; 

end


