% Reset the session: close open figures, wipe the workspace, clear the console.
close all;
clear;
clc;

% Run the project parameter script. Nt, Nr1, Nr2 and the other settings used
% below are not defined in this file, so they are presumably set there.
load_parameters;

% Refuse to run when either RIS dimension has fewer than 8 elements.
if any([Nr1, Nr2] < 8)
    error("The number of elements of RIS at each dimension should not be less than 8!");
end

%% Parameters
SNR_dB        = 10;                 % operating SNR in dB
SNR           = 10.^(SNR_dB./10);   % the same SNR on a linear scale
overhead_list = 4:8:100;            % overhead values swept: 4, 12, ..., 100
iter_max      = 5;                  % number of outer iterations that are averaged
save_flag     = 1;                  % non-zero: write the rate curves to ./data

%% Exhaustive codebook
% The codebooks are cached as .mat files keyed by the array size. When
% regenerate_flag is set and a cache file is missing, the codebook is generated
% and saved; in every other case the cached file is loaded.
regenerate_flag = 1;
codebook_dir = './codebook';
if ~exist(codebook_dir, 'dir')
    % save() does not create missing folders; without this a fresh checkout
    % would fail only after the codebook had already been generated.
    mkdir(codebook_dir);
end
file_path_BS = ['./codebook/codebook_Exhaustive_BS_', num2str(Nt), '_0402.mat'];
file_path_RIS = ['./codebook/codebook_Exhaustive_RIS_', num2str(Nr1), '_', num2str(Nr2), '_0402.mat'];

% BS side
if regenerate_flag && ~exist(file_path_BS, 'file')
    exhaustive_codebook_BS = exhaustive_codebook(Nt, 1);
    save(file_path_BS, 'exhaustive_codebook_BS');
else
    % Load only the expected variable so a stale cache file cannot silently
    % overwrite unrelated workspace variables.
    load(file_path_BS, 'exhaustive_codebook_BS');
end

% RIS side
if regenerate_flag && ~exist(file_path_RIS, 'file')
    exhaustive_codebook_RIS = exhaustive_codebook(Nr1, Nr2);
    save(file_path_RIS, 'exhaustive_codebook_RIS');
else
    load(file_path_RIS, 'exhaustive_codebook_RIS');
end

%% binary hierarchical codebook
% Same caching scheme as the other codebooks: generate and save when
% regenerate_flag is set and the cache file is missing, otherwise load it.
regenerate_flag = 1;
codebook_dir = './codebook';
if ~exist(codebook_dir, 'dir')
    % save() does not create missing folders; make sure the cache folder exists
    % before any codebook is generated.
    mkdir(codebook_dir);
end
file_path_BS = ['./codebook/codebook_binary_BS_', num2str(Nt), '_0402.mat'];
file_path_RIS = ['./codebook/codebook_binary_RIS_', num2str(Nr1), '_', num2str(Nr2), '_0402.mat'];
bitflag = "ideal";

% BS side
if regenerate_flag && ~exist(file_path_BS, 'file')
    binary_codebook_BS = binary_hierarchical_codebook(Nt, 1, overrate_Nt, 1, At, 1, "BS", bitflag);
    save(file_path_BS, 'binary_codebook_BS');
else
    % Load only the expected variable so a stale cache file cannot silently
    % overwrite unrelated workspace variables.
    load(file_path_BS, 'binary_codebook_BS');
end

% RIS side
if regenerate_flag && ~exist(file_path_RIS, 'file')
    binary_codebook_RIS = binary_hierarchical_codebook(Nr1, Nr2, overrate_Nr1, overrate_Nr2, Ar1, Ar2, "RIS", bitflag);
    save(file_path_RIS, 'binary_codebook_RIS');
else
    load(file_path_RIS, 'binary_codebook_RIS');
end

%% Hamming codebook
% Same caching scheme as the other codebooks: generate and save when
% regenerate_flag is set and the cache file is missing, otherwise load it.
% Each cache file holds the codebook together with the two matrices returned
% alongside it by Ohamming_codebook (OG_* and OH_*).
codebook_dir = './codebook';
if ~exist(codebook_dir, 'dir')
    % save() does not create missing folders; make sure the cache folder exists
    % before any codebook is generated.
    mkdir(codebook_dir);
end
file_path_BS = ['./codebook/codebook_Ohamming_BS_', num2str(Nt), '_', num2str(Nr1), '_', num2str(Nr2), '_0402.mat'];
file_path_RIS = ['./codebook/codebook_Ohamming_RIS_', num2str(Nt), '_', num2str(Nr1), '_', num2str(Nr2), '_0402.mat'];
regenerate_flag = 1;
bitflag = "ideal";

% BS side
if regenerate_flag && ~exist(file_path_BS, 'file')
    [Ohamming_codebook_BS, OG_BS, OH_BS] = Ohamming_codebook(Nt, 1, Nr1, Nr2, overrate_Nt, 1, At, 1, "BS", bitflag);
    save(file_path_BS, 'Ohamming_codebook_BS', 'OG_BS', 'OH_BS');
else
    % Load only the expected variables so a stale cache file cannot silently
    % overwrite unrelated workspace variables.
    load(file_path_BS, 'Ohamming_codebook_BS', 'OG_BS', 'OH_BS');
end

% RIS side
if regenerate_flag && ~exist(file_path_RIS, 'file')
    [Ohamming_codebook_RIS, OG_RIS, OH_RIS] = Ohamming_codebook(Nt, 1, Nr1, Nr2, overrate_Nr1, overrate_Nr2, Ar1, Ar2, "RIS", bitflag);
    save(file_path_RIS, 'Ohamming_codebook_RIS', 'OG_RIS', 'OH_RIS');
else
    load(file_path_RIS, 'Ohamming_codebook_RIS', 'OG_RIS', 'OH_RIS');
end

% Derived from the OH_* matrices; passed to the Hamming training routines below.
Oc_BS = c_generate(OH_BS);
Oc_RIS = c_generate(OH_RIS);

%% initialization
% One rate accumulator per scheme, one entry per overhead value
% (Rate0: exhaustive, Rate1: binary, Rate2: Hamming, Rate3: Hamming-2bit).
[Rate0, Rate1, Rate2, Rate3] = deal(zeros(1, numel(overhead_list)));

% Wall-clock reference used by the elapsed-time progress printout.
t0 = clock;

%% iteration
% Outer loop: iter_max repetitions whose rates are summed into Rate0..Rate3 and
% averaged after the loop. Inner loop: sweep over overhead_list. For every
% (iteration, overhead) pair, four training schemes are run on the same channel
% pair and their achievable rates are accumulated.
for iter = 1:iter_max
    % Obtain and normalize the two channels (H_br and H_ru) for this iteration.
    % NOTE(review): presumably each call draws a fresh random realization and the
    % extra outputs describe the optimal angles/indices - confirm in channel_BR /
    % channel_RU.
    [H_br, h_b, h_r, theta_opt, phi_opt11, phi_opt12, idx_Nt_opt, idx_Nr1_opt1, idx_Nr2_opt1] = channel_BR(Nt, Nr1, Nr2, 1);
    H_br = channelnorm_BR(H_br, Nt, Nr1, Nr2);
    [H_ru, phi_opt21, phi_opt22, idx_Nr1_opt2, idx_Nr2_opt2] = channel_RU(Nr1, Nr2, 1);
    H_ru = channelnorm_RU(H_ru, Nr1, Nr2);
    % Combine the two 1-based indices per RIS dimension, wrapping into 1..Nr1 / 1..Nr2.
    % These combined indices are not used again in the visible part of the script.
    idx_Nr1_opt = mod(idx_Nr1_opt1 + idx_Nr1_opt2 - 2, Nr1) + 1;
    idx_Nr2_opt = mod(idx_Nr2_opt1 + idx_Nr2_opt2 - 2, Nr2) + 1;

    for sidx = 1:length(overhead_list)
        overhead = overhead_list(sidx);
        %% Exhaustive
        % The training output is used directly as the BS vector and the RIS diagonal.
        [at_BS, at_RIS] = exhaustive_train(H_br, H_ru, SNR_dB, exhaustive_codebook_BS, exhaustive_codebook_RIS, overhead);
        array_gain = abs(H_ru*diag(at_RIS)*H_br*at_BS)^2;
        % real() discards any imaginary part before accumulating the rate.
        Rate0(sidx) = Rate0(sidx) + real(log2(1 + SNR*array_gain));

        %% Binary
        [at_BS, at_RIS] = binary_train(H_br, H_ru, Nt, Nr1, Nr2, SNR_dB, binary_codebook_BS, binary_codebook_RIS, overhead);
        % An output shorter than the full array size is not used as-is; a
        % full-size vector is rebuilt from it with g_generate and the GS routines.
        % NOTE(review): presumably a short output is a set of candidate beam
        % indices rather than a beamforming vector - confirm in binary_train.
        if length(at_BS) < Nt
            g = g_generate(Nt, 1, overrate_Nt, 1, at_BS);
            at_BS = GS_ideal_window(g, 100, At, 0.3, 0.7);
        end
        if length(at_RIS) < Nr1*Nr2
            % Split the first/last entries of at_RIS into a range along each RIS
            % dimension: ceil(k/Nr2) for the first, mod(k-1, Nr2)+1 for the second.
            % NOTE(review): assumes at_RIS holds 1-based linear indices with Nr2
            % entries per step of the first dimension - confirm.
            at_RIS1 = ceil(at_RIS(1)/Nr2):ceil(at_RIS(end)/Nr2);
            at_RIS2 = mod(at_RIS(1) - 1, Nr2) + 1:mod(at_RIS(end) - 1, Nr2) + 1;
            g = g_generate(Nr1, 1, overrate_Nr1, 1, at_RIS1);
            v1 = GS_RIS_mod(g, 100, Ar1, 0.3, 0.7, "ideal");
            g = g_generate(Nr2, 1, overrate_Nr2, 1, at_RIS2);
            v2 = GS_RIS_mod(g, 100, Ar2, 0.3, 0.7, "ideal");
            % The two per-dimension vectors are combined with a Kronecker product.
            at_RIS = kron(v1, v2);
        end

        array_gain = abs(H_ru*diag(at_RIS)*H_br*at_BS)^2;
        Rate1(sidx) = Rate1(sidx) + real(log2(1 + SNR*array_gain));

        %% Hamming
        % Same post-processing as the binary scheme, applied to hamming_train output.
        [at_BS, at_RIS] = hamming_train(H_br, H_ru, Nt, Nr1, Nr2, SNR_dB, Ohamming_codebook_BS, Ohamming_codebook_RIS, OH_BS, OH_RIS, Oc_BS, Oc_RIS, overhead);
        if length(at_BS) < Nt
            g = g_generate(Nt, 1, overrate_Nt, 1, at_BS);
            at_BS = GS_ideal_window(g, 100, At, 0.3, 0.7);
        end
        if length(at_RIS) < Nr1*Nr2
            at_RIS1 = ceil(at_RIS(1)/Nr2):ceil(at_RIS(end)/Nr2);
            at_RIS2 = mod(at_RIS(1) - 1, Nr2) + 1:mod(at_RIS(end) - 1, Nr2) + 1;
            g = g_generate(Nr1, 1, overrate_Nr1, 1, at_RIS1);
            v1 = GS_RIS_mod(g, 100, Ar1, 0.3, 0.7, "ideal");
            g = g_generate(Nr2, 1, overrate_Nr2, 1, at_RIS2);
            v2 = GS_RIS_mod(g, 100, Ar2, 0.3, 0.7, "ideal");
            at_RIS = kron(v1, v2);
        end        
        array_gain = abs(H_ru*diag(at_RIS)*H_br*at_BS)^2;
        Rate2(sidx) = Rate2(sidx) + real(log2(1 + SNR*array_gain));         
        
        %% Hamming-2bit
        % Same post-processing again, applied to hamming_train_2bit output.
        % m1 and m2 are not defined in this script; presumably they come from
        % load_parameters - confirm.
        [at_BS, at_RIS] = hamming_train_2bit(H_br, H_ru, Nt, Nr1, Nr2, SNR_dB, Ohamming_codebook_BS, Ohamming_codebook_RIS, OH_BS, OH_RIS, Oc_BS, Oc_RIS, overhead, m1, m2);
        if length(at_BS) < Nt
            g = g_generate(Nt, 1, overrate_Nt, 1, at_BS);
            at_BS = GS_ideal_window(g, 100, At, 0.3, 0.7);
        end
        if length(at_RIS) < Nr1*Nr2
            at_RIS1 = ceil(at_RIS(1)/Nr2):ceil(at_RIS(end)/Nr2);
            at_RIS2 = mod(at_RIS(1) - 1, Nr2) + 1:mod(at_RIS(end) - 1, Nr2) + 1;
            g = g_generate(Nr1, 1, overrate_Nr1, 1, at_RIS1);
            v1 = GS_RIS_mod(g, 100, Ar1, 0.3, 0.7, "ideal");
            g = g_generate(Nr2, 1, overrate_Nr2, 1, at_RIS2);
            v2 = GS_RIS_mod(g, 100, Ar2, 0.3, 0.7, "ideal");
            at_RIS = kron(v1, v2);
        end        
        array_gain = abs(H_ru*diag(at_RIS)*H_br*at_BS)^2;
        Rate3(sidx) = Rate3(sidx) + real(log2(1 + SNR*array_gain)); 

        % Progress line: current overhead, position in the sweep, iteration, elapsed seconds.
        fprintf('overhead = %4d [%d/%d] | iteration:[%d/%d] | run %.4f s\n', overhead, sidx, length(overhead_list), iter, iter_max, etime(clock, t0));
     end
end
% Turn the accumulated sums into averages over the iter_max iterations.
Rate0 = real(Rate0./iter_max);
Rate1 = real(Rate1./iter_max);
Rate2 = real(Rate2./iter_max);
Rate3 = real(Rate3./iter_max);

% Persist the averaged rate curves together with the swept overhead values.
if save_flag
    % save() does not create missing folders; create ./data first so the
    % results of the whole simulation are not lost to a write error.
    if ~exist('./data', 'dir')
        mkdir('./data');
    end
    save(['./data/Rate_pilot_', num2str(Nt), '_', num2str(Nr1), '_', num2str(Nr2), '_0402.mat'], 'overhead_list', 'Rate0', 'Rate1', 'Rate2', 'Rate3');
end

%% Figure
% Plot the averaged achievable rate of each scheme against the swept overhead.
[all_themes, all_colors] = GetColors();
figure;
hold on;
plot(overhead_list, Rate0, '--', 'Color', all_colors(1, :), 'LineWidth', 1.5, 'MarkerFaceColor', 'w');
plot(overhead_list, Rate1, '-<', 'Color', all_colors(2, :), 'LineWidth', 1.5, 'MarkerFaceColor', 'w');
plot(overhead_list, Rate2, '->', 'Color', all_colors(3, :), 'LineWidth', 1.5, 'MarkerFaceColor', 'w');
% NOTE(review): Rate3 (Hamming-2bit) is computed and saved but not plotted here;
% confirm whether it should appear in this figure.

grid on;
box on;
legend('Exhaustive beam training', 'Binary hierarchical beam training', 'Proposed coded beam training');
% The abscissa is overhead_list (the training overhead), not the SNR, which is
% fixed at SNR_dB for the whole sweep.
xlabel('Pilot overhead');
% Spectral efficiency is measured in bit/s/Hz, i.e. bps/Hz.
ylabel('Achievable rate (bps/Hz)');














