#>>>import
import os
import gc
import sys
import tqdm
import torch
import logging
import subprocess
import numpy as np
import pandas as pd
import xgboost as xgb
from pathlib import Path
import concurrent
import random
from concurrent.futures import ThreadPoolExecutor
from sklearn.utils import shuffle
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc, accuracy_score, f1_score, recall_score, precision_score, matthews_corrcoef

#>>>function
# Run configuration supplied by the host: number of folds (K), sub-samples per
# fold (n) and the worker-thread budget. `java_K`, `java_n`, `threads` and `seed`
# are not defined in this file; they must already exist in the namespace.
K = java_K
n = java_n
max_workers = threads

# Seed the Python, NumPy and torch RNGs so that the sampling and the parameter
# initialisation below are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

# One learnable (beta, gamma) pair per fold, passed to abs_log10_ver1/ver2 by the
# select section. beta starts at 1 and gamma is drawn from U(0, 1). The guard keeps
# an existing params_group (and its updated values) if one is already defined.
if 'params_group' not in locals():
    params_group = [
        {
            'beta': torch.nn.Parameter(torch.nn.init.constant_(torch.empty(1), 1)),
            'gamma': torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(1), 0, 1)),
        }
        for _ in range(K)
    ]
# Unlike params_group, the containers below are reset every time this section runs.
# NOTE(review): no visible code writes to global_models_save_list, global_x_train_all,
# global_y_train_all or the global_*_metrics_df tables, although the get_importance,
# predict and save sections read them — presumably the host fills them; confirm.
global_models_save_list = [0 for _ in range(K)]
global_x_train_all = []
global_y_train_all = []
global_train_metrics_df = pd.DataFrame(columns=['epoch', 'fold', 'sample', 'auc_roc', 'auc_pr', 'ks', 'best_accuracy', 'best_threshold', 'best_recall', 'best_precision', 'best_f1', 'best_mcc'])
global_test_metrics_df = pd.DataFrame(columns=['epoch', 'fold', 'sample', 'auc_roc', 'auc_pr', 'ks', 'best_accuracy', 'best_threshold', 'best_recall', 'best_precision', 'best_f1', 'best_mcc'])

def abs_log10_ver1(p_values, beta, gamma, summary_pvalues):
    """Turn p-values into sampling probabilities controlled by (beta, gamma).

    Each score starts as -log10(p). Where ``summary_pvalues`` is not the -999.0
    sentinel, the score is blended with it:
    ``score = sigmoid(beta) * score + (1 - sigmoid(beta)) * summary``.
    The result is ``exp(gamma * score) / sum(exp(gamma * score))``.

    Args:
        p_values: array-like of p-values.
        beta: one-element torch tensor; sigmoid(beta) is the blending weight.
        gamma: one-element torch tensor; the temperature applied to the scores.
        summary_pvalues: array aligned with ``p_values``; -999.0 marks "no value".
            It is used as-is (no -log10 applied) — presumably already on the
            -log10 scale; confirm with the caller.

    Returns:
        1-D torch tensor of probabilities summing to 1, built with torch ops on
        ``beta`` and ``gamma`` so gradients can flow back to them.
    """
    scores = torch.tensor(-np.log10(p_values))
    has_summary = torch.as_tensor(np.asarray(summary_pvalues) != -999.0)
    summary = torch.tensor(summary_pvalues)
    blend = torch.sigmoid(beta)
    scores[has_summary] = blend * scores[has_summary] + (1 - blend) * summary[has_summary]

    logits = gamma * scores
    if logits.numel() > 0:
        # Shift by the (detached) maximum before exponentiating. The normalised
        # result is mathematically unchanged, but exp() can no longer overflow to
        # inf. Previously tiny p-values or a large gamma produced inf / inf = NaN.
        logits = logits - logits.max().detach()
    weights = torch.exp(logits)
    return weights / torch.sum(weights)

def abs_log10_ver2(p_values, beta, gamma, summary_pvalues, cadd):
    """Like abs_log10_ver1, but each weight is additionally scaled by ``cadd``.

    Each score starts as -log10(p). Where ``summary_pvalues`` is not the -999.0
    sentinel, the score is blended with it using weight sigmoid(beta). The result
    is ``cadd * exp(gamma * score)`` normalised to sum to 1.

    Args:
        p_values: array-like of p-values.
        beta: one-element torch tensor; sigmoid(beta) is the blending weight.
        gamma: one-element torch tensor; the temperature applied to the scores.
        summary_pvalues: array aligned with ``p_values``; -999.0 marks "no value".
        cadd: per-variant multiplier aligned with ``p_values`` (tensor or array);
            presumably non-negative functional scores — confirm with the caller.

    Returns:
        1-D torch tensor of probabilities summing to 1.
    """
    scores = torch.tensor(-np.log10(p_values))
    has_summary = torch.as_tensor(np.asarray(summary_pvalues) != -999.0)
    summary = torch.tensor(summary_pvalues)
    blend = torch.sigmoid(beta)
    scores[has_summary] = blend * scores[has_summary] + (1 - blend) * summary[has_summary]

    logits = gamma * scores
    if logits.numel() > 0:
        # Shift by the (detached) maximum before exponentiating. The shift cancels
        # in the normalisation below, but it keeps exp() from overflowing to inf
        # (which previously turned the whole result into NaN).
        logits = logits - logits.max().detach()
    weights = cadd * torch.exp(logits)
    return weights / torch.sum(weights)


def merge_and_normalize(arr1, arr2, dtype=np.float32):
    """Concatenate two 2-D arrays column-wise and z-score every column.

    Args:
        arr1: first array (or array-like); rows are samples.
        arr2: second array with the same number of rows as ``arr1``.
        dtype: floating dtype used for the merged data and the statistics.

    Returns:
        Array of shape (rows, cols1 + cols2). Each column has its NaN-ignoring
        mean subtracted and is divided by its NaN-ignoring (population) standard
        deviation. Zero-spread columns are only centred, and NaN entries stay NaN.

    Raises:
        ValueError: if the inputs differ in their first dimension.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    if arr1.shape[0] != arr2.shape[0]:
        raise ValueError("输入数组的第一维必须相同")

    merged = np.concatenate((arr1, arr2), axis=1).astype(dtype)
    means = np.nanmean(merged, axis=0, dtype=dtype)
    stds = np.nanstd(merged, axis=0, dtype=dtype)
    # Constant columns: divide by 1 (i.e. just centre them) instead of by 0.
    stds[stds == 0] = 1
    # Vectorised standardisation; NaNs propagate unchanged.
    return (merged - means) / stds


def balanced_downsample_indices(labels):
    """Return shuffled indices of a class-balanced subset of binary ``labels``.

    All samples of the smaller class are kept. The same number of samples is
    drawn without replacement from the larger class, and the combined indices
    are shuffled. Both steps use the global NumPy RNG.

    Raises:
        ValueError: if ``labels`` does not contain exactly two distinct values.
    """
    label_arr = np.array(labels)
    classes = np.unique(label_arr)
    if len(classes) != 2:
        raise ValueError("标签必须包含且仅包含2个类别")

    # Positions of each class, ordered smallest class first (stable on ties).
    small, large = sorted((np.where(label_arr == c)[0] for c in classes), key=len)

    picked_large = np.random.choice(large, size=len(small), replace=False)
    combined = np.concatenate([small, picked_large])
    # Shuffle so the output is not grouped by class.
    np.random.shuffle(combined)
    return combined.tolist()

def sample_indices0(i, j, probs, select_size, n):
    """Draw ``select_size`` distinct indices from range(len(probs)), weighted by ``probs``.

    A private RandomState seeded with ``seed + i * n + j`` (module-level ``seed``)
    is used, so each (i, j) draw is reproducible regardless of thread scheduling.
    """
    stream = np.random.RandomState(seed + i * n + j)
    population = len(probs)
    return stream.choice(population, size=select_size, p=probs, replace=False)

def sample_indices(probabilities_list, K, n, select_size):
    """Draw n index subsets for each of the first K probability vectors.

    Returns a list of K lists, each holding n Python lists of ``select_size``
    indices; entry [i][j] comes from ``sample_indices0(i, j, ...)``. The draws run
    on a thread pool sized by the module-level ``threads``. Results are gathered
    in (i, j) order, so the output does not depend on scheduling.
    """
    result = []
    with ThreadPoolExecutor(max_workers=threads) as pool:
        for fold_idx in range(K):
            fold_probs = probabilities_list[fold_idx]
            # Submit every draw of this fold, then wait for them in j order.
            pending = [
                pool.submit(sample_indices0, fold_idx, draw_idx, fold_probs, select_size, n)
                for draw_idx in range(n)
            ]
            result.append([fut.result().tolist() for fut in pending])
    return result

def calculate_metrics(y_true, y_pred_prob):
    """Compute ranking metrics plus the metrics at the accuracy-maximising threshold.

    Args:
        y_true: binary ground-truth labels (NumPy array, sequence, or torch tensor).
        y_pred_prob: predicted positive-class probabilities, same length.

    Returns:
        Tuple ``(auc_roc, auc_pr, ks, best_accuracy, best_threshold, best_recall,
        best_precision, best_f1, best_mcc)``. The ``best_*`` values are taken at the
        first of 101 evenly spaced thresholds in [0, 1] that reaches the highest
        accuracy; ties keep the lower threshold.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.cpu().numpy()
    if isinstance(y_pred_prob, torch.Tensor):
        y_pred_prob = y_pred_prob.cpu().detach().numpy()
    # Plain sequences are accepted too; the threshold comparison needs an ndarray.
    y_true = np.asarray(y_true)
    y_pred_prob = np.asarray(y_pred_prob)

    auc_roc = roc_auc_score(y_true, y_pred_prob)
    # Area under the precision-recall curve.
    pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_pred_prob)
    auc_pr = auc(pr_recall, pr_precision)
    # Kolmogorov-Smirnov statistic: max separation between TPR and FPR.
    fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
    ks = max(tpr - fpr)

    thresholds = np.linspace(0, 1, 101)  # 101 thresholds: 0.00, 0.01, ..., 1.00
    best_accuracy = 0
    best_threshold = 0
    best_recall = 0
    best_precision = 0
    best_f1 = 0
    best_mcc = 0

    for threshold in thresholds:
        y_pred = (y_pred_prob >= threshold).astype(int)
        accuracy = accuracy_score(y_true, y_pred)
        # Only accuracy picks the winner, so the other metrics are computed just
        # for thresholds that improve it. The reported values are unchanged, but
        # four sklearn calls per non-improving threshold are skipped.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_threshold = threshold
            best_recall = recall_score(y_true, y_pred, zero_division=0)
            best_precision = precision_score(y_true, y_pred, zero_division=0)
            best_f1 = f1_score(y_true, y_pred, zero_division=0)
            best_mcc = matthews_corrcoef(y_true, y_pred)

    return auc_roc, auc_pr, ks, best_accuracy, best_threshold, best_recall, best_precision, best_f1, best_mcc

def parse_direct_genotype(byte_buffer):
    """Decode a 2-bit packed genotype matrix received from the host.

    Layout, as read below:
      * bytes 0-3 hold ``rows`` and bytes 4-7 hold ``cols`` (big-endian unsigned);
      * then ceil(rows * cols / 32) big-endian int64 words follow;
      * row-major element k occupies bits 2*(k % 32) and 2*(k % 32)+1 of word k // 32.
    Codes 0, 1 and 2 are returned as-is; code 3 becomes NaN.

    Args:
        byte_buffer: the raw bytes, as a uint8 array, bytes/bytearray/memoryview,
            or any other sequence of byte values.

    Returns:
        float32 array of shape (rows, cols).

    Raises:
        ValueError: if the buffer is shorter than the header says.
    """
    if isinstance(byte_buffer, (bytes, bytearray, memoryview)):
        byte_array = np.frombuffer(byte_buffer, dtype=np.uint8)
    else:
        # np.asarray copies only when needed. The old np.array(copy=False) raises
        # under NumPy >= 2 whenever a copy is unavoidable.
        byte_array = np.asarray(byte_buffer, dtype=np.uint8)

    rows = int.from_bytes(byte_array[:4].tobytes(), byteorder='big')
    cols = int.from_bytes(byte_array[4:8].tobytes(), byteorder='big')
    total_elements = rows * cols
    longs_needed = (total_elements + 31) // 32  # 32 two-bit codes per 64-bit word
    data_bytes = byte_array[8:8 + longs_needed * 8]
    long_array = np.frombuffer(data_bytes, dtype='>i8', count=longs_needed)

    # Expand each word into its 32 codes at once: column s holds bits 2s..2s+1.
    # The arithmetic shift of negative words is harmless because of the & 0x03 mask.
    shifts = np.arange(0, 64, 2, dtype=np.int64)
    codes = ((long_array[:, None] >> shifts) & 0x03).reshape(-1)[:total_elements]

    values = codes.astype(np.float32)
    values[codes == 3] = np.nan  # code 3 = missing genotype
    return values.reshape(rows, cols)

def parse_direct_binary(byte_buffer):
    """Decode a bit-packed binary vector (e.g. labels) received from the host.

    Layout, as read below:
      * bytes 0-3 hold ``size`` (big-endian unsigned);
      * bytes 4-7 are padding and are ignored;
      * then ceil(size / 64) big-endian uint64 words follow;
      * element k is bit (k % 64) of word k // 64.

    Args:
        byte_buffer: the raw bytes, as a uint8 array, bytes/bytearray/memoryview,
            or any other sequence of byte values.

    Returns:
        uint8 array of length ``size`` holding 0/1 values.

    Raises:
        ValueError: if the buffer is shorter than the header says.
    """
    if isinstance(byte_buffer, (bytes, bytearray, memoryview)):
        byte_array = np.frombuffer(byte_buffer, dtype=np.uint8)
    else:
        # np.asarray copies only when needed. The old np.array(copy=False) raises
        # under NumPy >= 2 whenever a copy is unavoidable.
        byte_array = np.asarray(byte_buffer, dtype=np.uint8)

    size = int.from_bytes(byte_array[:4].tobytes(), byteorder='big')
    longs_needed = (size + 63) // 64  # 64 one-bit elements per word
    data_bytes = byte_array[8:8 + longs_needed * 8]
    long_array = np.frombuffer(data_bytes, dtype='>u8', count=longs_needed)

    # Expand all words at once instead of looping over them in Python.
    shifts = np.arange(64, dtype=np.uint64)
    bits = (long_array[:, None] >> shifts) & np.uint64(1)
    return bits.reshape(-1)[:size].astype(np.uint8)


###>>>sample
from collections import Counter
# Build K stratified folds over sample indices into `labels` (host-provided).
# With `down_sample` set, every class is first randomly reduced to the size of the
# smallest class. The result `folds` is a list of [train_indices, val_indices] pairs.
if down_sample:
    label_counts = Counter(labels)
    min_count = min(label_counts.values())  # size of the smallest class
    # Down-sample every class to min_count samples
    downsampled_indices = []
    for label in label_counts:
        # All sample indices carrying this label
        label_indices = np.where(np.array(labels) == label)[0]
        # Keep min_count of them, drawn without replacement
        selected = np.random.choice(label_indices, size=min_count, replace=False)
        downsampled_indices.extend(selected)
    # Shuffle the retained indices
    downsampled_indices = np.random.permutation(downsampled_indices)
    # Split the retained samples into K stratified folds
    skf = StratifiedKFold(n_splits=K, shuffle=True, random_state=seed)
    downsampled_labels = [labels[int(i)] for i in downsampled_indices]  # labels of the retained samples
    folds = []
    for train_idx, val_idx in skf.split(np.zeros(len(downsampled_labels)), downsampled_labels):
        # Map fold positions back to indices into the original `labels`
        train_indices = downsampled_indices[train_idx].tolist()
        val_indices = downsampled_indices[val_idx].tolist()
        folds.append([train_indices, val_indices])
else:
    skf = StratifiedKFold(n_splits=K, shuffle=True, random_state=seed)
    folds = []
    for train_index, val_index in skf.split(np.zeros(len(labels)), labels):
        folds.append([train_index.tolist(), val_index.tolist()])


#>>>select
# Create the per-fold probability buffer on first use.
# NOTE(review): this list is only appended to here. Unless the host resets it
# between epochs, later epochs append after the first K entries while the training
# section reads probabilities_list[j] — confirm.
if 'probabilities_list' not in locals():
    probabilities_list = []

# Inputs for fold k supplied by the host.
k = java_k
p_values = np.array(java_p_values)
summary_pvalues = np.array(java_summary_pvalues)
select_size = java_select_size
# Kept under a non-builtin name: the previous `type = ...` shadowed the builtin
# `type()` for every later section sharing this namespace.
function_score_type = java_function_score_type

# Convert fold k's p-values into sampling probabilities with that fold's (beta, gamma).
if function_score_type == "cadd":
    if 'cadd_score' not in locals():
        raise ValueError('cadd_score not defined')
    probabilities = abs_log10_ver2(p_values, params_group[k]["beta"], params_group[k]["gamma"], summary_pvalues, cadd_score)
elif function_score_type == "normal":
    probabilities = abs_log10_ver1(p_values, params_group[k]["beta"], params_group[k]["gamma"], summary_pvalues)
else:
    # Previously an unknown type silently appended nothing. That left
    # probabilities_list short, and the code failed later somewhere unrelated.
    raise ValueError(f"unknown function score type: {function_score_type!r}")
# Stored detached as a NumPy array (presumably consumed by sample_indices).
probabilities_list.append(probabilities.detach().numpy())

#>>>train
# Fail fast if the function section (which creates params_group) has not run.
if 'params_group' not in locals():
    raise ValueError('params_group not defined')
# Cross-epoch state, created on the first training call only: the best validation
# AUC seen per fold, the models that reached it, and the per-(epoch, fold) metrics log.
if 'best_models_auc_roc' not in locals():
    best_models_auc_roc = [0 for _ in range(K)]
if 'models_save_list' not in locals():
    models_save_list = [0 for _ in range(K)]
if 'train_df' not in locals():
    train_df = pd.DataFrame(columns=['epoch', 'fold', 'sample', 'auc_roc', 'auc_pr', 'ks', 'best_accuracy', 'best_threshold', 'best_recall', 'best_precision', 'best_f1', 'best_mcc'])

def process_train(i):
    """Build the n training feature matrices of fold ``i`` from the host inputs.

    The layout depends on the module-level flags:
      * ignoreGty: covariates only. The transposed covariate matrix of fold i is
        returned as the same array object, repeated n times.
      * needCovar: each decoded genotype block is merged with that covariate
        matrix and column-standardised by merge_and_normalize.
      * otherwise: the decoded genotype blocks alone.

    Returns:
        List of n feature matrices, ordered by sub-sample index j.
    """
    if ignoreGty:
        covars = np.array(java_train_covars[i]).T
        return [covars] * n
    if needCovar:
        covars = np.array(java_train_covars[i]).T
        return [merge_and_normalize(parse_direct_genotype(java_X_train[i][k]), covars) for k in range(n)]
    return [parse_direct_genotype(java_X_train[i][k]) for k in range(n)]

def process_validate(i):
    """Build the n validation feature matrices of fold ``i`` from the host inputs.

    Mirrors process_train but reads ``java_validate_covars`` / ``java_X_validate``:
      * ignoreGty: the transposed covariate matrix of fold i, repeated n times
        (the same array object);
      * needCovar: each decoded genotype block merged with those covariates and
        column-standardised;
      * otherwise: the decoded genotype blocks alone.

    Returns:
        List of n feature matrices, ordered by sub-sample index j.
    """
    if ignoreGty:
        covars = np.array(java_validate_covars[i]).T
        return [covars] * n
    if needCovar:
        covars = np.array(java_validate_covars[i]).T
        return [merge_and_normalize(parse_direct_genotype(java_X_validate[i][k]), covars) for k in range(n)]
    return [parse_direct_genotype(java_X_validate[i][k]) for k in range(n)]


# Decode every fold's inputs on a thread pool. Per fold i this gives n training and
# n validation feature matrices (X_*_all[i][j]) plus one train and one validation
# label vector (y_*_all[i]). executor.map keeps the results in fold order.
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    X_train_all = list(executor.map(process_train, range(K)))
    X_validate_all = list(executor.map(process_validate, range(K)))
    y_train_all = list(executor.map(lambda i: parse_direct_binary(java_y_train[i]), range(K)))
    y_validate_all = list(executor.map(lambda i: parse_direct_binary(java_y_validate[i]), range(K)))

# ====================== Training stage: parallel model fitting ======================
def train_single_model(args):
    """Fit one XGBoost classifier for (fold i, sub-sample j) and score its validation set.

    Args:
        args: an ``(i, j)`` tuple indexing X_train_all / y_train_all / X_validate_all.

    Returns:
        ``(j, fitted_model, positive-class probabilities on X_validate_all[i][j])``.
    """
    fold, draw = args
    clf = xgb.XGBClassifier(
        max_depth=5,
        learning_rate=0.1,
        n_estimators=100,
        objective='binary:logistic',
        random_state=seed,
        eval_metric='logloss',
        n_jobs=max(1, threads//8)
    )
    clf.fit(X_train_all[fold][draw], y_train_all[fold])
    positive_proba = clf.predict_proba(X_validate_all[fold][draw])[:, 1]
    return draw, clf, positive_proba

# Per-fold results of this training call.
loss_list = []
auc_roc_list = []
models = []
criterion = torch.nn.BCELoss(reduction="mean")
currentI = java_current_i  # host-supplied epoch counter; recorded as currentI + 1

for i in range(K):
    temp_models = [None] * n
    y_pre_all = [None] * n
    # One training task per sub-sample j of fold i
    tasks = [(i, j) for j in range(n)]
    # Fit the n models of this fold in parallel
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {executor.submit(train_single_model, task): task for task in tasks}
        # Results arrive in completion order; slot them by j
        for future in concurrent.futures.as_completed(future_to_index):
            j, model, y_pred = future.result()
            temp_models[j] = model
            y_pre_all[j] = y_pred

    # Fold ensemble = mean prediction of its n models; score it on the validation set
    y_pre_all_mean = np.mean(y_pre_all, axis=0)
    loss = criterion(torch.tensor(y_pre_all_mean, dtype=torch.float), torch.tensor(y_validate_all[i], dtype=torch.float))
    auc_roc, auc_pr, ks, best_accuracy, best_threshold, best_recall, best_precision, best_f1, best_mcc = calculate_metrics(y_validate_all[i], y_pre_all_mean)

    # Record this (epoch, fold) row
    train_df.loc[len(train_df)] = [int(currentI+1), int(i+1), 'mean', auc_roc, auc_pr, ks, best_accuracy, best_threshold, best_recall, best_precision, best_f1, best_mcc]

    models.append(temp_models)
    loss_list.append(loss)
    auc_roc_list.append(auc_roc)

# For each fold whose validation AUC improved: keep its models and update that
# fold's (beta, gamma) sampling parameters.
change_index = []
for j in range(K):
    if auc_roc_list[j] > best_models_auc_roc[j]:
        best_models_auc_roc[j] = auc_roc_list[j]
        models_save_list[j] = models[j]
        change_index.append(j)
        if not ignoreGty:
            l2_lambda = 0.01
            learning_rate = 0.1
            regularization = l2_lambda * (torch.norm(params_group[j]['beta'], p=2) ** 2 + torch.norm(params_group[j]['gamma'], p=2) ** 2)
            # Distinct feature indices drawn for fold j across its n sub-samples
            indices = list(set(list(np.array(indices_selected_list[j]).reshape(-1))))
            # NOTE(review): despite the name, these are the sampling probabilities
            # (not log-probabilities) stored by the select section.
            log_probs = torch.tensor(probabilities_list[j])
            selected_log_probs = log_probs[indices]
            # NOTE(review): built from detached NumPy data, so this is a new graph leaf.
            # The gradients below reach beta/gamma only through `regularization`, not
            # through the sampling probabilities — confirm this is intended.
            log_prob_sum = (selected_log_probs.sum() - len(indices) * log_probs.mean()).requires_grad_(True)

            # Scale by fold j's own validation loss. The previous code used `loss`,
            # which after the loop above always held the LAST fold's loss.
            policy_loss = loss_list[j].detach() * (torch.log(log_prob_sum + regularization))
            # Gradients w.r.t. both parameters from a single backward pass
            policy_grad_beta, policy_grad_gamma = torch.autograd.grad(
                policy_loss, [params_group[j]['beta'], params_group[j]['gamma']]
            )
            # Manual gradient-descent step outside autograd
            with torch.no_grad():
                params_group[j]['beta'] -= learning_rate * policy_grad_beta
                params_group[j]['gamma'] -= learning_rate * policy_grad_gamma

#>>>test
# Build the test feature matrices X_test_all[i][j] for every fold i and sub-sample j,
# using the same three layouts as process_train. Unlike training, one covariate
# matrix (java_test_covars, not indexed by fold) is shared by all folds.
if ignoreGty:
    test_covars = np.array(java_test_covars)
    X_test_all = [[test_covars.T for _ in range(n)] for _ in range(K)]
elif needCovar:
    test_covars = np.array(java_test_covars)
    X_test_all = [[merge_and_normalize(parse_direct_genotype(java_X_test[i][j]), test_covars.T) for j in range(n)] for i in range(K)]
else:
    X_test_all = [[parse_direct_genotype(java_X_test[i][j]) for j in range(n)] for i in range(K)]

# Test labels: a single vector shared by all folds
y_test_all = parse_direct_binary(java_y_test)

# Test stage: parallel prediction
def predict_single_model(args):
    """Return ``model``'s positive-class probabilities on X_test_all[i][j].

    Args:
        args: an ``(i, j, model)`` tuple; ``model`` must provide predict_proba.
    """
    fold, draw, fitted = args
    proba = fitted.predict_proba(X_test_all[fold][draw])
    return proba[:, 1]

# Create the test-metrics table on first use
if 'test_df' not in locals():
    test_df = pd.DataFrame(columns=['epoch', 'fold', 'sample', 'auc_roc', 'auc_pr', 'ks', 'best_accuracy', 'best_threshold', 'best_recall', 'best_precision', 'best_f1', 'best_mcc'])
if 'models_save_list' not in locals():
    raise ValueError('models_save_list not defined')

# One prediction task per stored best model (K x n tasks)
test_tasks = []
for i in range(K):
    for j in range(n):
        test_tasks.append((i, j, models_save_list[i][j]))

# Run all predictions in parallel
y_pre_all = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    # Submit every task at once
    futures = [executor.submit(predict_single_model, task) for task in test_tasks]
    # Collected in completion order. The order does not matter because the
    # predictions are only averaged below.
    for future in concurrent.futures.as_completed(futures):
        y_pre_all.append(future.result())

# Ensemble = mean prediction over all K x n models; compute the test metrics on it
y_pre_all = np.array(y_pre_all)
y_pre_all_mean = np.mean(y_pre_all, axis=0)
auc_roc, auc_pr, ks, best_accuracy, best_threshold, best_recall, best_precision, best_f1, best_mcc = calculate_metrics(y_test_all, y_pre_all_mean)

# Append this epoch's test row
test_df.loc[len(test_df)] = [int(currentI+1), 'mean', 'mean', auc_roc, auc_pr, ks, best_accuracy, best_threshold, best_recall, best_precision, best_f1, best_mcc]

# Release large per-epoch objects and run the garbage collector.
# NOTE(review): X_train_all / y_train_all are not released here — confirm whether
# the host still needs them (e.g. to fill global_x_train_all).
X_validate_all = None
X_test_all = None
y_validate_all = None
y_test_all = None
models = None
_ = gc.collect()

###>>>get_importance
from collections import defaultdict

def train_random_model(args):
    """Fit one XGBoost model on label-permuted data for the permutation test.

    Args:
        args: ``(count, i, j, global_x_train_all, global_y_train_all)``, i.e. the
            permutation round, the fold, the sub-sample, and the stored training
            features/labels.

    Returns:
        ``(j, fitted_model)``. Fold i's labels are shuffled with random_state
        ``seed + count*100 + i*10 + j``; this is distinct per (i, j) within a round
        as long as i, j < 10.
    """
    count, i, j, x_store, y_store = args
    permuted_y = shuffle(y_store[i], random_state=seed+count*100+i*10+j)
    clf = xgb.XGBClassifier(
        max_depth=5,
        learning_rate=0.1,
        n_estimators=100,
        objective='binary:logistic',
        random_state=seed,
        eval_metric='logloss',
        n_jobs=max(1, threads//8)
    )
    clf.fit(x_store[i][j], permuted_y)
    return j, clf

def process_model(args):
    """Map one model's per-feature 'gain' importances onto SNP names.

    Args:
        args: ``(i, j, models_list, all_snp_list)``. ``models_list[i][j]`` must
            provide get_booster(), and ``all_snp_list[i][j]`` is indexed by feature
            position.

    Returns:
        defaultdict(float) mapping SNP name -> gain. Feature names are assumed to
        be XGBoost's default ``'f<index>'`` form (the first character is stripped).
    """
    i, j, models_list, all_snp_list = args
    booster = models_list[i][j].get_booster()
    names = all_snp_list[i][j]
    scores = booster.get_score(importance_type='gain')

    gains = defaultdict(float)
    for feature_name, feature_gain in scores.items():
        gains[names[int(feature_name[1:])]] = feature_gain
    return gains

def calculate_importance_scores(K, n, models_list, all_snp_list, max_workers):
    """Average the XGBoost 'gain' importance of each SNP over all K x n models.

    Each model's gains come from process_model. A SNP's score is the mean gain over
    the models in which it appears, not over all K x n models.

    Returns:
        ``(avg_gains, total_usage)``: a dict mapping snp -> mean gain, and a
        defaultdict(int) mapping snp -> number of models that reported a gain for it.
    """
    gain_sums = defaultdict(float)
    usage = defaultdict(int)
    jobs = [(i, j, models_list, all_snp_list) for i in range(K) for j in range(n)]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map() yields results in job order, so the sums accumulate deterministically.
        for per_model in pool.map(process_model, jobs):
            for snp, gain in per_model.items():
                gain_sums[snp] += gain
                usage[snp] += 1

    # Every accumulated SNP was counted at least once, so the division is safe.
    avg_gains = {snp: total / usage[snp] for snp, total in gain_sums.items()}
    return avg_gains, usage

# SNP names per model: all_snp_list[i][j][k] is the SNP behind feature k of model (i, j).
# NOTE(review): np.array on ragged nested lists fails (or gives an object array),
# so this assumes every model has the same number of features — confirm.
all_snp_list = np.array(java_all_snp_list)

# Observed importance scores from the stored best models
real_gains, total_usage = calculate_importance_scores(K, n, global_models_save_list, all_snp_list, max_workers)

# Permutation test: retrain every model on shuffled labels. Each round records the
# largest average gain over all SNPs (a max-statistic null distribution).
max_gain_per_permutation = []
if permutation != 0:
    for count in tqdm.tqdm(range(permutation), desc="Model Permutation Test Processing Progress", unit="count"):
        random_models = []
        for i in range(K):
            temp_models = [None] * n
            # One task per sub-sample j of fold i
            tasks = [(count, i, j, global_x_train_all, global_y_train_all) for j in range(n)]
            # Fit this fold's models in parallel
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit all tasks
                future_to_index = {executor.submit(train_random_model, task): task for task in tasks}
                # Slot results by j as they complete
                for future in concurrent.futures.as_completed(future_to_index):
                    j, model = future.result()
                    temp_models[j] = model
            random_models.append(temp_models)

        # Importance scores under this permutation
        perm_gains, _ = calculate_importance_scores(K, n, random_models, all_snp_list, max_workers)

        # Largest gain across all SNPs in this round (0 if no model reported any)
        if perm_gains:
            current_max_gain = max(perm_gains.values())
        else:
            current_max_gain = 0

        max_gain_per_permutation.append(current_max_gain)

# Exported results
export_p = {}
export_gains = {}
export_usages = {}

if permutation != 0:
    for snp in real_gains:
        true_score = real_gains[snp]

        # Count the rounds whose maximum permuted gain is at least the observed gain
        count_extreme = 0
        for max_gain in max_gain_per_permutation:
            if true_score <= max_gain:
                count_extreme += 1

        # Empirical p-value with +1 smoothing: (count_extreme + 1) / (permutation + 1)
        p_value = (count_extreme + 1) / (permutation + 1)
        export_p[snp] = p_value

# Export the average gain and the usage count of every SNP
for snp in real_gains:
    export_gains[snp] = real_gains[snp]
    export_usages[snp] = total_usage[snp]

###>>>predict
# Build the prediction feature matrices with the same three layouts as the test
# section: one shared covariate matrix (java_predict_covars) for all folds.
if ignoreGty:
    predict_covars = np.array(java_predict_covars)
    X_predict_all = [[predict_covars.T for j in range(n)] for i in range(K)]
elif needCovar:
    predict_covars = np.array(java_predict_covars)
    X_predict_all = [[merge_and_normalize(parse_direct_genotype(java_X_predict[i][j]), predict_covars.T) for j in range(n)] for i in range(K)]
else:
    X_predict_all = [[parse_direct_genotype(java_X_predict[i][j]) for j in range(n)] for i in range(K)]

# One list of positive-class probabilities per stored model, in (i, j) order.
# No averaging happens here; presumably the caller aggregates — confirm.
y_predict_all = []
for i in range(K):
    for j in range(n):
        y_pred_proba = global_models_save_list[i][j].predict_proba(X_predict_all[i][j])[:, 1]
        y_predict_all.append(y_pred_proba.tolist())


#>>>save
# Persist the stored models and the global train/test metric tables.
# NOTE(review): the function section always defines global_models_save_list, so this
# guard only fires if that section never ran.
if 'global_models_save_list' not in locals():
    raise ValueError('global_models_save_list not defined')

model_save_path = Path(java_model_save_path)
train_save_path = Path(java_train_save_path)

# Metrics are written as tab-separated text without the index column.
global_train_metrics_df.to_csv(train_save_path, sep='\t', index=False)
# The list of fitted models is serialised with torch.save (pickle-based).
# NOTE(review): recent torch.load defaults to weights_only=True, which may refuse
# these objects when reloading — confirm how the file is read back.
torch.save(global_models_save_list, model_save_path)

test_save_path = Path(java_test_save_path)
global_test_metrics_df.to_csv(test_save_path, sep='\t', index=False)
