/**********************************************************************
* tallyBAM.cpp: an R interface for tallying nucleotide counts from a .bam file
* Adapted from the deepSNV::bam2R function by Moritz Gerstung
***********************************************************************/

#include <stdio.h>
#include <string.h>
#include <string>
#include <iostream>
#include <samtools-1.7-compat.h>
#include <map>
#define R_NO_REMAP
#include <R.h>
#include <Rinternals.h>
#include <R_ext/Rdynload.h>

using namespace std;


// Per-call state shared between _tallyBAM and the pileup callbacks, which
// receive a pointer to it through their `void *data` argument.
struct nttable_t {
  int beg;         // region start on the reference, 0-based (inclusive)
  int end;         // region end on the reference, 0-based (exclusive)
  int q;           // base-quality cutoff: a base is counted only if its quality > q
  int s;           // strand selector as passed from R (2 = both); not read by the callbacks in this file
  int head_clip;   // bases within head_clip of either end of the read sequence go to the "b"/"e" strata
  int i;           // number of in-region pileup columns visited so far
  int *counts;     // caller-owned output, indexed [(pos - beg) * 2 * N + strand * N + column]
  std::map<std::string, int> nt_idx;  // label ("A", "Cb", "-e", ...) -> column within one strand block
  samfile_t *in;   // SAM/BAM handle opened by _tallyBAM
};

// Symbols whose counts are stored: the four nucleotides and two indel markers
// ('-' deletion, '+' insertion). Every symbol is kept in three read-position
// strata ("b", middle, "e"), so N is the number of count columns per strand.
char NUCLEOTIDES[] = {'A', 'C', 'G', 'T'};
unsigned int nnuc = sizeof(NUCLEOTIDES) / sizeof(NUCLEOTIDES[0]);
char SPECIAL[] = {'-', '+'};
unsigned int nspecial = sizeof(SPECIAL) / sizeof(SPECIAL[0]);
int N = 3 * (nnuc + nspecial);


// callback for bam_fetch()
static int fetch_func(const bam1_t *b, void *data)
{
  bam_plbuf_t *buf = (bam_plbuf_t*)data;
  bam_plbuf_push(b, buf);
  return 0;
}

static int pileup_func_old(uint32_t tid, hts_pos_t pos, int n, const bam_pileup1_t *pl, void *data)
{
  nttable_t *nttable = (nttable_t*)data;
  int i, s;
  int len = nttable->end - nttable->beg;
  if ((int)pos >= nttable->beg && (int)pos < nttable->end){
    int* counts = nttable->counts; // pointing to the beginning of the array
    for (i=0; i<n; i++){
      const bam_pileup1_t *p = pl + i;
      s = bam1_strand(p->b);
      std::cout << "pos: " << pos << " IsDel?" << p->is_del << " InDel:" << p->indel << std::endl;
      if(!(p->is_del)){
        char c;
        if (p->indel == 0 && bam1_qual(p->b)[p->qpos] > nttable->q) {
          c = char(bam_nt16_rev_table[bam1_seqi(bam1_seq(p->b), p->qpos)]);
        }else if(p->indel < 0){
          c = '-';
        }else if(p->indel > 0){
          c = '+';
        }
        basic_string<char> str;
        str += c;
        std::cout << "Counting: " << str << std::endl;
        if( bam1_qual(p->b)[p->qpos] > nttable->q){
          if (((bam1_core_t)p->b->core).l_qseq - p->qpos < nttable->head_clip){
            str += "e";
          }else if (p->qpos < nttable->head_clip) {
            str += "b";
          }
          if(c != '-'){
            counts[ ((int)pos - nttable->beg) * 2 * N + (s * N) + nttable->nt_idx[str] ]++;
          }else{
            //Rprintf("Deletion of length: %i\n", p->indel);
            for(int j = p->indel; j < 0; ++j){
              //Rprintf("Position: %i, j is %i\n", (int)pos - nttable->beg - j, j);
              if(nttable->end > (int)pos - nttable->beg - j){
                //Rprintf("Counting at: %i, j is %i\n", ((int)pos - nttable->beg - j), j);
                counts[ ((int)pos - nttable->beg - j) * 2 * N + (s * N) + nttable->nt_idx[str] ]++;
              }
            }
          }
        }
      }
    }
    nttable->i++;
  }
  return 0;
}

static int pileup_func(uint32_t tid, hts_pos_t pos, int n, const bam_pileup1_t *pl, void *data)
{
  nttable_t *nttable = (nttable_t*)data;
  int i, s;
  int len = nttable->end - nttable->beg;
  if ((int)pos >= nttable->beg && (int)pos < nttable->end){
    int* counts = nttable->counts; // pointing to the beginning of the array
    for (i=0; i<n; i++){
      const bam_pileup1_t *p = pl + i;
      if( bam1_qual(p->b)[p->qpos] > nttable->q){
        s = bam1_strand(p->b); 
        char c;
        basic_string<char> stratum = "";
        basic_string<char> indexLabel = "";
        if (((bam1_core_t)p->b->core).l_qseq - p->qpos < nttable->head_clip){
          stratum += "e";
        }else if (p->qpos < nttable->head_clip) {
          stratum += "b";
        }
        if(p->is_del){
          indexLabel = "-";
          indexLabel += stratum;
          counts[ ((int)pos - nttable->beg) * 2 * N + (s * N) + nttable->nt_idx[indexLabel] ]++;
          continue;
        }
        if(p->indel > 0){
          indexLabel = "+";
          indexLabel += stratum;
          counts[ ((int)pos - nttable->beg) * 2 * N + (s * N) + nttable->nt_idx[indexLabel] ]++;
          indexLabel = "";
        }
        if (bam1_qual(p->b)[p->qpos] > nttable->q) {
          c = char(bam_nt16_rev_table[bam1_seqi(bam1_seq(p->b), p->qpos)]);
          indexLabel += c;
          indexLabel += stratum;
          counts[ ((int)pos - nttable->beg) * 2 * N + (s * N) + nttable->nt_idx[indexLabel] ]++;
        }
      }
    }
    nttable->i++;
  }
  return 0;
}

extern "C" {

/* Fill `nt_idx` with the label -> column mapping used by the pileup callbacks.
 * Within each strand's N-wide block of `counts` the columns are:
 *    0-3   A,C,G,T  with suffix "b" (within head_clip of the read start)
 *    4-7   A,C,G,T  without suffix
 *    8-11  A,C,G,T  with suffix "e" (within head_clip of the read end)
 *   12-14  '-' (deletion)  for suffixes "b", none, "e"
 *   15-17  '+' (insertion) for suffixes "b", none, "e"
 * NOTE(review): the R side presumably relies on this exact layout; keep in sync. */
static void fill_nt_idx(std::map<std::string, int> &nt_idx)
{
  static const struct { const char *label; int column; } kColumns[] = {
    {"Ab", 0}, {"Cb", 1}, {"Gb", 2},  {"Tb", 3},
    {"A", 4},  {"C", 5},  {"G", 6},   {"T", 7},
    {"Ae", 8}, {"Ce", 9}, {"Ge", 10}, {"Te", 11},
    {"-b", 12}, {"-", 13}, {"-e", 14},
    {"+b", 15}, {"+", 16}, {"+e", 17}
  };
  for (size_t k = 0; k < sizeof(kColumns) / sizeof(kColumns[0]); ++k)
    nt_idx[kColumns[k].label] = kColumns[k].column;
}

/* Does the real work of _tallyBAM. Every C++ object and htslib resource is
 * owned by this frame and released before it returns, because the caller
 * reports failures with Rf_error(), which longjmps and would otherwise skip
 * destructors and leave file/index handles open.
 * Returns 0 on success; on failure writes a message into `msg` and returns 1. */
static int tally_bam_impl(char *bamfile, char *ref, int beg, int end, int *counts,
                          int q, int s, int head_clip, int maxdepth, int verbose,
                          char *msg, size_t msglen)
{
  nttable_t nttable;
  nttable.q = q;                 // base quality cutoff
  nttable.s = s;                 // strand (2 = both)
  nttable.head_clip = head_clip;
  nttable.i = 0;
  nttable.counts = counts;
  nttable.beg = beg - 1;         // 1-based start from R -> 0-based
  nttable.end = end;
  fill_nt_idx(nttable.nt_idx);

  nttable.in = samopen(bamfile, "rb", 0);
  if (nttable.in == 0) {
    snprintf(msg, msglen, "Fail to open BAM file %s\n", bamfile);
    return 1;
  }

  if (strcmp(ref, "") == 0) { // no region specified: pile up the whole file
    // NOTE(review): pileup_func does not look at tid, so positions from every
    // reference sequence that fall in [beg, end) are summed together.
    sampileup(nttable.in, -1, pileup_func, &nttable);
  } else {
    bam_index_t *idx = bam_index_load(bamfile);
    if (idx == 0) {
      snprintf(msg, msglen, "BAM indexing file is not available.\n");
      samclose(nttable.in);
      return 1;
    }

    int tid = bam_get_tid(nttable.in->header, ref);
    if (tid < 0) {
      snprintf(msg, msglen, "Invalid sequence %s\n", ref);
      bam_index_destroy(idx);
      samclose(nttable.in);
      return 1;
    }
    if (verbose)
      Rprintf("Reading %s, %s:%i-%i\n", bamfile, ref, nttable.beg, nttable.end);

    bam_plbuf_t *buf = bam_plbuf_init(pileup_func, &nttable);
    if (buf == 0) {
      snprintf(msg, msglen, "Failed to initialize pileup buffer.\n");
      bam_index_destroy(idx);
      samclose(nttable.in);
      return 1;
    }
    bam_plp_set_maxcnt(buf->iter, maxdepth);
    bam_fetch(nttable.in->x.bam, idx, tid, nttable.beg, nttable.end, buf, fetch_func);
    bam_plbuf_push(0, buf); // finalize pileup
    bam_index_destroy(idx);
    bam_plbuf_destroy(buf);
  }
  samclose(nttable.in);
  return 0;
}

/* .C entry point: tally nucleotide/indel counts for region ref:beg-end
 * (1-based, inclusive) of `bamfile` into the caller-allocated `counts`
 * array. An empty `ref` piles up the whole file. Errors are raised via
 * Rf_error only after all resources have been released. Returns 0. */
int _tallyBAM(char** bamfile, char** ref, int* beg, int* end, int* counts, int* q, int* s, int* head_clip, int* maxdepth, int* verbose)
{
  char msg[4096]; // plain C buffer: safe to be skipped by Rf_error's longjmp
  if (tally_bam_impl(*bamfile, *ref, *beg, *end, counts, *q, *s, *head_clip,
                     *maxdepth, *verbose, msg, sizeof msg) != 0)
    Rf_error("%s", msg);
  return 0;
}

/* .C routine registration table. numArgs must equal _tallyBAM's arity (10
 * pointer arguments) or R rejects the call, and R walks the table until it
 * reaches the all-NULL terminator entry, which must therefore be present. */
R_CMethodDef cMethods[] = {
   {"_tallyBAM", (DL_FUNC) &_tallyBAM, 10},
   {NULL, NULL, 0}
};

/* Shared-library initialisation hook: R calls R_init_<dllname> when the
 * library is loaded. Registers the .C routines listed in cMethods; no .Call,
 * .Fortran or .External routines are registered. */
void R_init_tallyBAM(DllInfo *info)
{
   const R_CMethodDef *c_routines = cMethods;
   (void) R_registerRoutines(info, c_routines, NULL, NULL, NULL);
}

} // extern "C"
