function [at_BS, at_RIS, idx_Nt, idx_Nr1, idx_Nr2] = binary_train(H_br, H_ru, Nt, Nr1, Nr2, SNR, cb_BS, cb_RIS, overhead)
% BINARY_TRAIN  Hierarchical (binary-search) joint BS/RIS beam training.
%
%   Every training stage s sends four pilots, one for each pairing of the
%   two stage-s BS beams cb_BS(:, 1:2, s) with the two stage-s RIS vectors
%   cb_RIS(:, 1:2, s), in the order
%       (RIS 1, BS 1), (RIS 2, BS 1), (RIS 1, BS 2), (RIS 2, BS 2).
%   The received powers abs(H_ru*diag(w_RIS)*H_br*f_BS)^2 are passed
%   through awgn(., SNR). The strongest pairing of each stage yields one BS
%   bit and one RIS bit (bit = 1 means beam 2 of that stage won). The bits
%   are read MSB-first.
%
%   Inputs
%     H_br, H_ru - channels; H_ru*diag(w)*H_br*f must evaluate to a scalar.
%     Nt         - number of BS beams/antennas; log2(Nt) BS bits are kept
%                  (presumably a power of two - TODO confirm with callers).
%     Nr1, Nr2   - RIS dimensions; log2(Nr1*Nr2) RIS bits are kept.
%     SNR        - passed unchanged as the snr argument of awgn.
%     cb_BS      - BS codebook indexed (:, beam 1..2, stage).
%     cb_RIS     - RIS codebook indexed (:, beam 1..2, stage).
%                  If one codebook has fewer stages than the other, its last
%                  stage is reused for the remaining pilots.
%     overhead   - pilot budget. At most 4*max(#stages) pilots are sent.
%
%   Outputs
%     Truncated training (overhead < 4*max(#stages)). Only completely
%     measured stages (floor(overhead/4)) are used:
%       at_BS, at_RIS  - ranges of candidate grid indices starting at the
%                        index encoded by the decoded bits.
%       idx_Nt, idx_Nr1, idx_Nr2 - left at 1.
%     Full training:
%       idx_Nt         - selected BS grid index.
%       idx_Nr1/Nr2    - selected RIS grid index split into (row, column)
%                        of an Nr1-by-Nr2 grid (Nr2 varies fastest).
%       at_BS  = conj(array_response(Nt, 1, asin(theta), 0))
%       at_RIS = conj(array_response(Nr1, Nr2, asin(phi1), asin(phi2)))
%       where theta/phi1/phi2 are taken from uniform grids on (-1, 1).

    n_stage_BS = size(cb_BS, 3);
    n_stage_RIS = size(cb_RIS, 3);
    n_stage = max(n_stage_BS, n_stage_RIS);
    max_pilot = 4*n_stage;
    % Disabled option: credit surplus pilots to the SNR.
%     if overhead > max_pilot
%         SNR = SNR + 10*log10(overhead/max_pilot);
%     end

    idx_Nt = 1;
    idx_Nr1 = 1;
    idx_Nr2 = 1;

    % ---- Pilot measurements -------------------------------------------
    % Unsent pilots stay 0; they only fall into stages that are discarded
    % below.
    rx_power = zeros(max_pilot, 1);
    for p = 1:min(overhead, max_pilot)
        stage = ceil(p/4);
        slot = mod(p - 1, 4);           % position 0..3 inside the stage
        ris_col = mod(slot, 2) + 1;     % RIS beam toggles fastest
        bs_col = floor(slot/2) + 1;
        w_RIS = cb_RIS(:, ris_col, min(stage, n_stage_RIS));
        f_BS = cb_BS(:, bs_col, min(stage, n_stage_BS));
        rx_power(p) = abs(H_ru*diag(w_RIS)*H_br*f_BS)^2;
    end
    rx_power = awgn(rx_power, SNR);

    % ---- Per-stage decision -------------------------------------------
    % One column per stage. The winning row r encodes r-1 = 2*BS_bit + RIS_bit.
    % The bits are decoded arithmetically. The previous dec2bin/length
    % approach broke with a single stage, because length() of a 1x2 char
    % matrix is 2.
    [~, winner] = max(reshape(rx_power, 4, n_stage), [], 1);
    code = winner(:) - 1;
    idx_BS = floor(code/2);
    idx_RIS = mod(code, 2);

    n_bit_BS = log2(Nt);
    n_bit_RIS = log2(Nr1*Nr2);
    idx_BS = idx_BS(1:min(end, n_bit_BS));
    idx_RIS = idx_RIS(1:min(end, n_bit_RIS));

    if overhead < max_pilot
        % Keep only stages whose four pilots were all measured.
        n_full = floor(overhead/4);
        idx_BS = idx_BS(1:min(n_full, n_bit_BS));
        idx_RIS = idx_RIS(1:min(n_full, n_bit_RIS));

        % Decoded bits are the leading bits of an n_bit-wide index.
        angle_idx_BS = 1 + sum(idx_BS(:) .* 2.^(n_bit_BS - (1:numel(idx_BS)).'));
        angle_idx_RIS = 1 + sum(idx_RIS(:) .* 2.^(n_bit_RIS - (1:numel(idx_RIS)).'));

        % NOTE(review): after k decided bits a binary search leaves
        % 2^(n_bit - k) candidates; the extra "-1" keeps only half of them.
        % Preserved from the original - confirm intent.
        % Both exponents are clamped at 0 so a fully resolved index yields a
        % single candidate instead of an empty range (the RIS side previously
        % lacked this clamp).
        at_BS = angle_idx_BS:angle_idx_BS + 2^(max(n_bit_BS - numel(idx_BS) - 1, 0)) - 1;
        at_RIS = angle_idx_RIS:angle_idx_RIS + 2^(max(n_bit_RIS - numel(idx_RIS) - 1, 0)) - 1;
    else
        % All stages measured: the bits are read MSB-first over their own
        % length.
        n_BS = numel(idx_BS);
        n_RIS = numel(idx_RIS);
        angle_idx_BS = 1 + sum(idx_BS(:) .* 2.^(n_BS - (1:n_BS).'));
        angle_idx_RIS = 1 + sum(idx_RIS(:) .* 2.^(n_RIS - (1:n_RIS).'));

        idx_Nt = angle_idx_BS;
        idx_Nr1 = ceil(angle_idx_RIS/Nr2);
        idx_Nr2 = mod(angle_idx_RIS - 1, Nr2) + 1;

        % Uniform grids of N points on (-1, 1), spacing 2/N.
        theta_list = -1+1/Nt:2/Nt:1-1/Nt;
        phi_list1 = -1+1/Nr1:2/Nr1:1-1/Nr1;
        phi_list2 = -1+1/Nr2:2/Nr2:1-1/Nr2;
        at_BS = conj(array_response(Nt, 1, asin(theta_list(idx_Nt)), 0));
        at_RIS = conj(array_response(Nr1, Nr2, asin(phi_list1(idx_Nr1)), asin(phi_list2(idx_Nr2))));
    end
end













