#! /usr/bin/env python
# -*- coding: utf-8 -*-

"""
Add gene names to an edge/matrix file using a gene ID to gene name mapping table.

Input:
    A tab-delimited file with at least 5 columns:
        0: i
        1: j
        2: geneid1
        3: geneid2
        4: value

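    A tab-delimited gene list file (read without a header row) containing at least:
        - one column for gene ID   (0-based index selected with --i_id, default 0)
        - one column for gene name (0-based index selected with --i_name, default 1)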

Behavior:
    - Gene names for geneid1 and geneid2 are looked up in the mapping table and added as new columns.
    - If an ID is not found in the gene list, the original ID is used as its gene name instead of raising an error.

Output:
    A tab-delimited file with columns:
        i, j, geneid1, geneid2, genename1, genename2, val
"""

import pandas as pd
import argparse

_OUTPUT_COLUMNS = ["i", "j", "geneid1", "geneid2", "genename1", "genename2", "val"]


def _load_gene_map(path, id_col, name_col):
    """Return a ``{gene_id: gene_name}`` dict read from a headerless TSV gene list.

    Args:
        path: Path to the tab-delimited gene list (read with no header row).
        id_col: 0-based column index holding gene IDs.
        name_col: 0-based column index holding gene names.

    Raises:
        ValueError: If a column index is outside the file's columns, or both
            indices point at the same column.

    Both columns are read as strings so IDs are matched exactly as written
    (e.g. ``0042`` is not silently turned into the integer ``42``). When an
    ID occurs more than once, the last row wins.
    """
    if id_col == name_col:
        raise ValueError("gene id column and gene name column must differ")
    genes = pd.read_csv(
        path, sep="\t", header=None, dtype={id_col: str, name_col: str}
    )
    n_cols = genes.shape[1]
    for col in (id_col, name_col):
        if not 0 <= col < n_cols:
            raise ValueError(
                f"gene list column {col} is out of range (file has {n_cols} columns)"
            )
    return genes.set_index(id_col)[name_col].to_dict()


def _add_gene_names(edges, gene_map):
    """Return *edges* with gene-name columns added, in the documented output order.

    Args:
        edges: Headerless frame whose first five columns are
            i, j, geneid1, geneid2, value.
        gene_map: Mapping from gene ID to gene name.

    Raises:
        ValueError: If *edges* has fewer than five columns.
    """
    if edges.shape[1] < 5:
        raise ValueError(f"input must have at least 5 columns, got {edges.shape[1]}")
    # Only the first five columns are documented; any extra columns are dropped.
    out = edges.iloc[:, :5].copy()
    out.columns = ["i", "j", "geneid1", "geneid2", "val"]
    for id_col, name_col in (("geneid1", "genename1"), ("geneid2", "genename2")):
        # IDs absent from the gene list keep the original ID as their name.
        out[name_col] = out[id_col].map(gene_map).fillna(out[id_col])
    return out[_OUTPUT_COLUMNS]


def main():
    """Parse CLI arguments, add gene names to the edge file, and write the result.

    The output is a headerless TSV with columns
    i, j, geneid1, geneid2, genename1, genename2, val.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Input matrix", type=str)
    parser.add_argument("output", help="Output file", type=str)
    parser.add_argument("genelist", help="Gene list", type=str)
    parser.add_argument("--i_id", help="column number of gene id (default: 0)", type=int, default=0)
    parser.add_argument("--i_name", help="column number of gene name (default: 1)", type=int, default=1)

    args = parser.parse_args()

    gene_map = _load_gene_map(args.genelist, args.i_id, args.i_name)

    try:
        # Read the ID columns as strings so they match the gene list exactly.
        edges = pd.read_csv(args.input, sep="\t", header=None, dtype={2: str, 3: str})
    except pd.errors.EmptyDataError:
        # An empty edge file produces an empty output file rather than a crash.
        edges = pd.DataFrame(columns=range(5))

    _add_gene_names(edges, gene_map).to_csv(
        args.output, sep="\t", index=False, header=False
    )

if __name__ == "__main__":
    # Run the CLI only when executed as a script; main() returns None -> exit status 0.
    raise SystemExit(main())
