#! /usr/bin/env python
# -*- coding: utf-8 -*-
'''
EEISP: identify gene pairs that are codependent and mutually exclusive from single-cell RNA-seq data.
Copyright(c)  Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
All rights reserved.
'''

import sys
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import binom
import argparse
import multiprocessing as mp
import gc

def argwrapper(args):
    """Invoke a packed ``(callable, arg1, arg2, ...)`` sequence.

    The pool ``map`` calls in this file hand a single object to the worker
    function, so callers bundle the target callable together with its
    positional arguments and this helper unpacks them again.

    Args:
        args: Sequence whose first item is the callable and whose remaining
            items are the positional arguments passed to it.

    Returns:
        Whatever the wrapped callable returns.
    """
    func = args[0]
    params = args[1:]
    return func(*params)

def calcCDI_eachrow(i, ProbArray, CountArray, ngene, ncell):
    """Compute row ``i`` of the upper-triangular CDI score matrix.

    For every column ``j > i`` the score is ``-log10(P(X >= count_j))`` with
    ``X ~ Binomial(ncell, prob_j)``. Columns ``j <= i`` are filled with 0.

    Args:
        i: Index of the row being computed.
        ProbArray: 1-D array of success probabilities for row ``i``.
        CountArray: 1-D array of observed counts for row ``i``.
        ngene: Unused; kept so the call signature stays unchanged.
        ncell: Number of binomial trials.

    Returns:
        1-D float array with the same length as the inputs. Entries whose
        tail probability evaluates to 0 hold the sentinel -10000000.0.
    """
    # Cast to a signed type before subtracting: with an unsigned count array
    # ``0 - 1`` wraps around to the dtype maximum, which turns a tail
    # probability of 1 into 0.
    k = np.asarray(CountArray[i + 1:]).astype(np.int64) - 1
    p = ProbArray[i + 1:]
    prob = binom.sf(k, ncell, p)

    # np.where evaluates both branches, so log10(0) is still computed for the
    # entries that end up replaced by the sentinel; silence that warning only.
    with np.errstate(divide='ignore'):
        score = -np.log10(prob)
    val = np.where(prob <= 0, -10000000.0, score)

    return np.pad(val, (i + 1, 0), mode='constant', constant_values=0)

def calcEEI_eachrow(i, ProbArray_i, CountArray_i, ProbArray_T_i, CountArray_T_i, ngene, ncell):
    """Compute row ``i`` of the upper-triangular EEI score matrix.

    For every column ``j > i`` two binomial upper-tail probabilities are
    evaluated: one from the transposed counts with the row probabilities, and
    one from the row counts with the transposed probabilities. The score is
    the mean of their ``-log10`` values. Columns ``j <= i`` are filled with 0.

    Args:
        i: Index of the row being computed.
        ProbArray_i: Row ``i`` of the probability matrix.
        CountArray_i: Row ``i`` of the count matrix.
        ProbArray_T_i: Row ``i`` of the transposed probability matrix.
        CountArray_T_i: Row ``i`` of the transposed count matrix.
        ngene: Unused; kept so the call signature stays unchanged.
        ncell: Number of binomial trials.

    Returns:
        1-D float array with the same length as the inputs. Entries where
        either tail probability evaluates to 0 hold the sentinel -10000000.0.
    """
    # Cast to a signed type before subtracting: with an unsigned count array
    # ``0 - 1`` wraps around to the dtype maximum, which turns a tail
    # probability of 1 into 0.
    k1 = np.asarray(CountArray_T_i[i + 1:]).astype(np.int64) - 1
    prob1 = binom.sf(k1, ncell, ProbArray_i[i + 1:])

    k2 = np.asarray(CountArray_i[i + 1:]).astype(np.int64) - 1
    prob2 = binom.sf(k2, ncell, ProbArray_T_i[i + 1:])

    # np.where evaluates both branches, so log10(0) is still computed for the
    # entries that end up replaced by the sentinel; silence that warning only.
    with np.errstate(divide='ignore'):
        score = (-np.log10(prob1) + -np.log10(prob2)) / 2
    val = np.where((prob1 <= 0) | (prob2 <= 0), -10000000.0, score)

    return np.pad(val, (i + 1, 0), mode='constant', constant_values=0)

def genMatrix_MultiProcess(Prob_joint, Count_joint, MatType, ngene, ncell, *, ncore=4):
    """Build the full symmetric CDI or EEI score matrix using a process pool.

    One task per row is dispatched to ``calcCDI_eachrow`` or
    ``calcEEI_eachrow``. Each task returns the upper-triangular part of its
    row, and the stacked result is mirrored onto the lower triangle.

    Args:
        Prob_joint: 2-D probability matrix, indexed by row.
        Count_joint: 2-D count matrix (an array, or a sequence of rows).
        MatType: ``"CDI"`` or ``"EEI"``.
        ngene: Number of rows to process.
        ncell: Number of binomial trials, forwarded to the row workers.
        ncore: Number of worker processes.

    Returns:
        2-D symmetric score matrix.

    Raises:
        SystemExit: With a non-zero status if ``MatType`` is not recognised.
    """
    # Validate before spawning workers so an invalid request neither starts
    # nor leaks a pool, and make the failure visible through the exit status.
    if MatType not in ("CDI", "EEI"):
        print("Error: illegal MatType for genMatrix_MultiProcess.", file=sys.stderr)
        sys.exit(1)

    if MatType == "CDI":
        func_args = [
            (calcCDI_eachrow, i, Prob_joint[i], Count_joint[i], ngene, ncell)
            for i in range(ngene)
        ]
    else:
        # Transposes are views; take them once instead of once per row.
        # np.asarray lets a list of rows be transposed as well.
        Prob_T = np.asarray(Prob_joint).T
        Count_T = np.asarray(Count_joint).T
        func_args = [
            (calcEEI_eachrow, i, Prob_joint[i], Count_joint[i], Prob_T[i], Count_T[i], ngene, ncell)
            for i in range(ngene)
        ]

    # The context manager tears the pool down even if a worker raises.
    with mp.get_context('spawn').Pool(ncore) as pool:
        results = pool.map(argwrapper, func_args)

    Matrix = np.array(results)
    # Mirror the upper triangle; subtract the diagonal so it is not doubled.
    Matrix = Matrix + Matrix.T - np.diag(np.diag(Matrix))

    return Matrix

def count_sum_nonzeroMat(i, is_nonzeroMat):
    """Sum, for every row, its elementwise product with row ``i``.

    Args:
        i: Index of the reference row.
        is_nonzeroMat: 2-D array; row ``i`` is multiplied elementwise with
            each row of the array.

    Returns:
        1-D array holding the per-row sum of those products.
    """
    reference_row = is_nonzeroMat[i, :]
    overlap = np.multiply(reference_row, is_nonzeroMat)
    return np.sum(overlap, axis=1)

def generate_CDImatrix(A, args):
    """Compute the CDI score matrix for expression matrix ``A``.

    Args:
        A: 2-D array; ``A.shape[0]`` is used as the gene count and
            ``A.shape[1]`` as the cell count.
        args: Object with a ``threads`` attribute giving the pool size.

    Returns:
        Matrix produced by ``genMatrix_MultiProcess`` in ``"CDI"`` mode.
    """
    ngene = A.shape[0]
    ncell = A.shape[1]
    is_nonzeroMat = A > 0
    p_nonzero = np.sum(is_nonzeroMat, axis=1) / ncell

    print("Calculating CDI...")
    # Outer product of the per-gene non-zero fractions.
    Prob_joint = np.array(p_nonzero * p_nonzero[:, np.newaxis], dtype='float32')

    func_args = [(count_sum_nonzeroMat, i, is_nonzeroMat) for i in range(ngene)]

    # Smallest signed integer type that can hold a count of ``ncell``. It is
    # signed so that the ``count - 1`` done by the row workers cannot wrap.
    count_dtype = np.min_scalar_type(-(ncell + 1))

    # The context manager tears the pool down even if a worker raises.
    with mp.get_context('spawn').Pool(args.threads) as pool:
        Count_joint = np.array(pool.map(argwrapper, func_args), dtype=count_dtype)

    # Drop the task list so it no longer keeps the boolean matrix alive.
    del func_args

    CDI = genMatrix_MultiProcess(Prob_joint, Count_joint, "CDI", ngene, ncell, ncore=args.threads)

    return CDI

def count_sum_nonzeroMatnotA(i, is_nonzeroMat, notA):
    """Sum, for every row of ``notA``, its elementwise product with row ``i``.

    Args:
        i: Index of the reference row in ``is_nonzeroMat``.
        is_nonzeroMat: 2-D array supplying the reference row.
        notA: 2-D array whose rows are multiplied elementwise with that row.

    Returns:
        1-D array holding the per-row sum of those products.
    """
    reference_row = is_nonzeroMat[i, :]
    exclusive = np.multiply(reference_row, notA)
    return np.sum(exclusive, axis=1)

def generate_EEImatrix(A, args):
    """Compute the EEI score matrix for expression matrix ``A``.

    Args:
        A: 2-D array; ``A.shape[0]`` is used as the gene count and
            ``A.shape[1]`` as the cell count.
        args: Object with a ``threads`` attribute giving the pool size.

    Returns:
        Matrix produced by ``genMatrix_MultiProcess`` in ``"EEI"`` mode.
    """
    ngene = A.shape[0]
    ncell = A.shape[1]
    is_nonzeroMat = A > 0
    p_nonzero = np.sum(is_nonzeroMat, axis=1) / ncell
    p_zero = np.sum(A == 0, axis=1) / ncell

    print("Calculating EEI...")
    Prob_joint = np.array(p_nonzero * p_zero[:, np.newaxis], dtype='float32')
    notA = np.logical_not(A)

    func_args = [(count_sum_nonzeroMatnotA, i, is_nonzeroMat, notA) for i in range(ngene)]

    # Smallest signed integer type that can hold a count of ``ncell``. A fixed
    # uint16 silently overflows beyond 65535 cells, and an unsigned type makes
    # the ``count - 1`` done by the row workers wrap around for zero counts.
    count_dtype = np.min_scalar_type(-(ncell + 1))

    # The context manager tears the pool down even if a worker raises.
    with mp.get_context('spawn').Pool(args.threads) as pool:
        Count_excl = np.array(pool.map(argwrapper, func_args), dtype=count_dtype)

    # The task tuples reference both boolean matrices, so they must be dropped
    # as well for the memory to be released before the next step.
    del func_args
    del notA
    del is_nonzeroMat

    EEI = genMatrix_MultiProcess(Prob_joint, Count_excl, "EEI", ngene, ncell, ncore=args.threads)
    return EEI


def _pairs_over_threshold(Matrix, thre, genenames):
    """Return the table of pairs ``i < j`` scoring above ``thre`` and the per-column degree list."""
    scores = np.asarray(Matrix)
    over = scores > thre

    # Per-column count of entries above the threshold. The axis is stated
    # explicitly rather than relying on how np.sum treats a DataFrame.
    degree = np.count_nonzero(over, axis=0).tolist()

    # Strict upper triangle keeps each unordered pair once (i < j).
    # np.nonzero reports positions in row-major order.
    rows, cols = np.nonzero(np.triu(over, k=1))
    names = np.asarray(genenames)
    pairs = pd.DataFrame({
        "i": rows,
        "j": cols,
        "gene_i": names[rows],
        "gene_j": names[cols],
        "val": scores[rows, cols],
    })
    pairs = pairs.sort_values(["val", "i"], ascending=[False, True])
    return pairs, degree


def calc_degree(Matrix, thre, ngene, filename, output, genenames):
    """Write the gene pairs scoring above ``thre`` and return per-gene degrees.

    Two files are produced: ``<output>_<filename>.pdf``, a log-scaled
    histogram of the selected scores, and ``<output>_<filename>.txt``, a
    tab-separated table without header holding ``i, j, gene_i, gene_j, val``
    for every pair ``i < j`` whose score exceeds ``thre``. Rows are sorted by
    descending score, then ascending ``i``.

    Args:
        Matrix: 2-D score matrix.
        thre: Threshold; only entries strictly greater are kept.
        ngene: Unused; kept so the call signature stays unchanged.
        filename: Suffix used in the output file names.
        output: Prefix used in the output file names and the plot title.
        genenames: Gene labels, indexed by matrix position.

    Returns:
        List with, for each column of ``Matrix``, the number of entries
        greater than ``thre``.
    """
    df, degree = _pairs_over_threshold(Matrix, thre, genenames)

    # Create a histogram
    fig = plt.figure(figsize=(10, 6))
    plt.hist(df["val"], bins=50, color='blue', alpha=0.7, edgecolor='black')
    plt.title('Distribution of ' + output)
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.yscale('log')
    plt.grid(True)
    plt.savefig(output + "_" + filename + ".pdf")
    # Release the figure; otherwise each call leaves one open in pyplot.
    plt.close(fig)

    data_file = output + "_" + filename + ".txt"
    print ("output degree data in " + data_file)
    print ("number of gene pairs over threshold (>" + str(thre) + "): " + str(df.shape[0]))
    df.to_csv(data_file, sep="\t", header=False, index=False)
    return degree

def calc_degree_dist(degree, filename, args):
    """Write the degree distribution to ``<args.output>_<filename>_degree_distribution.tsv``.

    The table lists every degree value strictly greater than the minimum
    together with the number of genes having that degree, plus the natural
    logarithm of both columns.

    Args:
        degree: Non-empty list of integer degrees, one per gene.
        filename: Label inserted into the output file name.
        args: Object with an ``output`` attribute used as the file prefix.

    Raises:
        ValueError: If ``degree`` is empty.
    """
    max_value = max(degree)
    min_value = min(degree)
    value_range = max_value - min_value
    print(f"max degree:{max_value:.3f} min degree:{min_value:.3f} value_width={value_range:.3f}")

    # One pass over the data instead of calling list.count() for every
    # candidate value. np.unique returns the values in ascending order.
    values, counts = np.unique(np.asarray(degree), return_counts=True)
    # The minimum degree itself is left out, matching the previous
    # range(min_value + 1, ...) behaviour.
    # NOTE(review): presumably this skips degree 0, whose log is undefined.
    keep = values > min_value

    df = pd.DataFrame({"Degree": values[keep], "The number of genes": counts[keep]})
    log_df = np.log(df).rename(columns={"Degree": "Log_Degree", "The number of genes": "Log_The number of genes"})

    merge = pd.concat([log_df, df], axis=1)
    merge.to_csv(f"{args.output}_{filename}_degree_distribution.tsv", sep="\t")

def get_nonzero_matrix(input_data):
    """Drop rows that have no value above zero.

    Args:
        input_data: Table with one row per gene and one column per cell.

    Returns:
        Tuple ``(A, genenames)``: ``A`` is a numpy array holding only the rows
        with at least one value greater than 0, and ``genenames`` is the index
        of those rows.
    """
    values = np.array(input_data)
    expressed_rows = np.any(values > 0, axis=1)
    A = values[expressed_rows]

    # Per-row number of positive entries; rows with a positive count are kept.
    cells_per_gene = np.sum(input_data > 0, axis=1)
    kept = pd.DataFrame(cells_per_gene[cells_per_gene > 0])
    genenames = kept.index

    return A, genenames


def main():
    """Command-line entry point.

    Parses the arguments, loads the expression matrix, removes rows without
    any positive value, then computes and writes the CDI and/or EEI results.
    """
    parser = argparse.ArgumentParser(prog='eeisp')
    parser.add_argument("matrix", type=str, help="Input matrix")
    parser.add_argument("output", type=str, help="Output prefix")
    parser.add_argument("--threCDI", type=float, default=10, help="Threshold for CDI (default: 10.0)")
    parser.add_argument("--threEEI", type=float, default=5, help="Threshold for EEI (default: 5.0)")
    parser.add_argument("--tsv", action="store_true", help="Specify when the input file is tab-delimited (.tsv)")
    parser.add_argument("--CDIonly", action="store_true", help="Calculate CDI only")
    parser.add_argument("--EEIonly", action="store_true", help="Calculate EEI only")
    parser.add_argument("-p", "--threads", type=int, default=2, help="number of threads (default: 2)")
    parser.add_argument("-v", "--version", action='version', version='%(prog)s 0.6.0')

    args = parser.parse_args()
    print(args)

    startt = time.time()

    delimiter = "\t" if args.tsv else ","
    input_data = pd.read_csv(args.matrix, index_col=0, sep=delimiter)

    print("number of cells: ", input_data.shape[1])
    print("number of genes: ", input_data.shape[0])

    A, genenames = get_nonzero_matrix(input_data)
    # Free the parsed table before the large score matrices are built.
    del input_data
    gc.collect()

    ngene = A.shape[0]
    print("number of nonzero genes: ", ngene)
    print ("-----------------------------------------------")

    run_cdi = not args.EEIonly
    run_eei = not args.CDIonly

    if run_cdi:
        cdi = generate_CDImatrix(A, args)
        cdi_label = f"CDI_score_data_thre{args.threCDI}"
        degree_cdi = calc_degree(cdi, args.threCDI, ngene, cdi_label, args.output, genenames)
        calc_degree_dist(degree_cdi, "CDI", args)
        del cdi, degree_cdi
        gc.collect()
        print("Finish to calculate CDI!")

    if run_eei:
        eei = generate_EEImatrix(A, args)

        # The expression matrix is not needed once the EEI matrix exists.
        del A

        eei_label = f"EEI_score_data_thre{args.threEEI}"
        degree_eei = calc_degree(eei, args.threEEI, ngene, eei_label, args.output, genenames)
        calc_degree_dist(degree_eei, "EEI", args)
        del eei, degree_eei
        gc.collect()

        print("Finish to calculate EEI!")

    elapsed_time = time.time() - startt
    print(f"Elapsed_time:{elapsed_time}[sec]")
    print("*************************************************************")


# Run the command-line interface only when executed as a script. The guard
# also matters because this file creates pools with the 'spawn' start method,
# which requires the main module to be importable without side effects.
if __name__ == "__main__":
    main()
