#!/usr/bin/env python

from __future__ import print_function, absolute_import, division, unicode_literals

import sys

# Refuse to run on interpreters older than 2.7: supporting python2.6 would
# require handling a missing `argparse` and rewriting every '{}'.format() call
# as '{0}'.format().
if sys.version_info < (2, 7):
    print("bravo requires python2.7 or newer.  Please re-run it as `python2.7 bravo ...` or `python3 bravo ...`.", file=sys.stderr)
    sys.exit(1)

import argparse
import json
import os
import re
import time
import signal

# Specialized imports for python 2/3 cross-compatibility
try:
    from urllib.parse import urlencode # py3
except ImportError:
    from urllib import urlencode # py2

try:
    from urllib.request import Request, urlopen, HTTPError, URLError # py3
except ImportError:
    from urllib2 import Request, urlopen, HTTPError, URLError # py2

# Name the exception raised on malformed JSON: a dedicated class when the
# `json` module provides one, plain ValueError otherwise.
try:
    JSONDecodeError = json.JSONDecodeError # py3
except AttributeError:
    JSONDecodeError = ValueError # py2

# Make `basestring` usable in isinstance() checks on both major versions.
try:
    basestring # py2: the builtin exists, nothing to do
except NameError:
    basestring = str # py3: no `basestring` builtin, fall back to `str`


# Restore the default SIGPIPE action to prevent a messy exception when running
# `bravo ... | head`.  SIGPIPE is not defined on every platform (it is a Unix
# signal), so only install the handler where the constant exists instead of
# failing with AttributeError at start-up.
if hasattr(signal, 'SIGPIPE'):
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)


# Command-line interface: one sub-command per operation.  The name of the
# selected sub-command is stored in `args.command`.

def _add_output_option(parser):
    """Register the shared `-o/--output` format selector on `parser`."""
    parser.add_argument('-o', '--output', required=False, choices=['json', 'vcf'], default='json', dest='format',
                        help='Output format.')


def _add_filter_option(parser):
    """Register the shared `-f/--filter` expression option on `parser`."""
    parser.add_argument('-f', '--filter', metavar='expression', required=False, type=str, dest='filter',
                        help='Filtering expression.')


argparser = argparse.ArgumentParser(
    description='Manage Google Oauth 2.0 authentication for Bravo API and do querying from command line.')
commands = argparser.add_subparsers(dest='command', title='Commands')

# Authentication sub-commands (no options of their own).
login_command = commands.add_parser('login', help='Authorize access to Bravo API.')
token_command = commands.add_parser('print-access-token', help='Display access token for Bravo API.')
revoke_command = commands.add_parser('revoke', help='Revoke all access tokens for Bravo API.')

# Query sub-commands.
query_region_command = commands.add_parser('query-region', help='Query chromosomal region.')
query_gene_command = commands.add_parser('query-gene', help='Query by gene name or gene identifier.')
query_variant_command = commands.add_parser(
    'query-variant',
    help='Query variant by variant identifier or by chromosome name and chromosomal position.',
    description='Query variant by identifier CHROM-POS-REF-ALT, or by chromosome name and chromosomal position.')
query_meta_command = commands.add_parser('query-meta', help='Query version and data description.')
annotate_command = commands.add_parser(
    'annotate',
    help='Annotate input VCF.',
    description='Uncompressed input VCF must be streamed to standard input. Uncompressed output VCF is streamed to standard output. Multi-allelic variant records in input VCF must be split into multiple bi-allelic variant records.')

# query-region options.
query_region_command.add_argument('-c', '--chromosome', metavar='name', type=str, required=True, dest='chromosome',
                                  help='Chromosome name.')
query_region_command.add_argument('-s', '--start', metavar='base-pair', type=int, required=True, dest='start',
                                  help='Start position.')
query_region_command.add_argument('-e', '--end', metavar='base-pair', type=int, required=True, dest='end',
                                  help='End position.')
_add_output_option(query_region_command)
_add_filter_option(query_region_command)

# query-gene options.
query_gene_command.add_argument('-n', '--name', metavar='name', type=str, required=True, dest='gene',
                                help='Gene name or gene identifier.')
_add_output_option(query_gene_command)
_add_filter_option(query_gene_command)

# query-variant options: either -v alone, or -c together with -p (the
# combination is validated later, not by argparse).
query_variant_command.add_argument('-v', '--variant', metavar='chrom-pos-ref-alt/rs#', type=str, dest='variant_id',
                                   help='Variant identifier CHROM-POS-REF-ALT or rs#.')
query_variant_command.add_argument('-c', '--chromosome', metavar='name', type=str, dest='chromosome',
                                   help='Chromosome name.')
query_variant_command.add_argument('-p', '--position', metavar='base-pair', type=int, dest='position',
                                   help='Position.')
_add_output_option(query_variant_command)

# annotate options.
_add_filter_option(annotate_command)

# Location of the local credentials store: <home>/<BRAVO_DIR>/<BRAVO_CREDSTORE>.
USER_HOME = os.path.expanduser("~")
BRAVO_DIR = '.bravo'
BRAVO_CREDSTORE = 'credstore'

# Authentication endpoints; they all share one versioned URL prefix.
BRAVO_API_VERSION = 'v1'
_BRAVO_AUTH_PREFIX = 'https://bravo.sph.umich.edu/api/{}/auth/'.format(BRAVO_API_VERSION)
BRAVO_AUTH_API = _BRAVO_AUTH_PREFIX + 'auth'
BRAVO_TOKEN_API = _BRAVO_AUTH_PREFIX + 'token'
BRAVO_REVOKE_API = _BRAVO_AUTH_PREFIX + 'revoke'
BRAVO_IP_API = _BRAVO_AUTH_PREFIX + 'ip'


class BravoException(Exception):
    """Error carrying a human-readable message.

    The message is kept in the `message` attribute and is what `str()` returns.
    """

    def __init__(self, message):
        Exception.__init__(self, message)
        self.message = message

    def __str__(self):
        return self.message


class requests(object):
    """Implements the parts we need of the real `requests` module.

    Both helpers build a urllib `Request` and wrap it in `_requests_response`.
    """

    @staticmethod
    def get(url, headers=None, params=None):
        """GET `url`; a truthy `params` is url-encoded and appended after '?'."""
        query = '?' + urlencode(params) if params else ''
        return _requests_response(Request(url + query, headers=headers or {}))

    @staticmethod
    def post(url, headers=None, data=None):
        """POST to `url`; a dict `data` is url-encoded, and any body is sent as UTF-8 bytes."""
        body = urlencode(data) if isinstance(data, dict) else data
        if body is not None:
            body = body.encode('utf8')
        return _requests_response(Request(url, headers=headers or {}, data=body))
class _requests_response(object):
    """Perform `request` on construction and expose `status_code` and `json()`.

    An HTTP error status does not raise: the status is stored in `status_code`,
    the originating HTTPError in `_exception`, and `json()` returns the decoded
    error body (an empty dict when that body is not valid JSON, so callers can
    still use `.get('error', default)`).  On success `_exception` is None.

    Raises:
        BravoException: when the connection fails, or when a successful
            response does not contain valid JSON.
    """
    # urlopen error-handling:
    # - if it can't connect (or the connection drops), it raises URLError
    # - 0-100 => BadStatusLine
    # - 101-199 => HTTPError
    # - 2xx => ok
    # - 3xx => HTTPError, except that 301/302/303/307 with valid "Location:" header behave according to the new location
    # - 4xx/5xx => HTTPError
    # For comparison, with <https://github.com/requests/requests>:
    # - if it can't connect (or the connection drops), it raises requests.exceptions.ConnectionError
    # - 0-100 => requests.exceptions.ConnectionError
    # - 101-199 => ok
    # - 2xx => ok
    # - 3xx => ok, except that 301/302/303/307 with valid "Location:" header behave according to the new location
    # - 4xx/5xx => ok, but requests.get(url).raise_for_status() raises requests.exceptions.HTTPError
    # So, when receiving only 2xx/4xx/5xx and valid 301/302/303/307 and using .raise_for_status(), they behave the same.
    def __init__(self, request):
        # Always defined, so callers may read it whatever the outcome was.
        self._exception = None
        try:
            response = urlopen(request)
        except HTTPError as http_exc:
            self.status_code = http_exc.getcode()
            self._exception = http_exc
            try:
                self._json = _json_load_str_or_bytes(http_exc.fp)
            except ValueError:
                # Covers JSON decoding errors (a ValueError subclass where a
                # dedicated class exists).  No name is bound here on purpose:
                # rebinding the outer exception name would unbind it on
                # python3 once this clause ends.
                self._json = {}
        except URLError as url_exc:
            raise BravoException("Failed to connect: {}".format(url_exc.reason))
        else:
            try:
                self.status_code = response.getcode()
                # Read through the response object rather than its raw `.fp`,
                # so that the response's own read() logic is applied.
                self._json = _json_load_str_or_bytes(response)
            except ValueError:
                raise BravoException('Bravo API server returned a response that is not valid JSON.')
            finally:
                response.close()

    def json(self):
        """Return the decoded JSON body captured at construction time."""
        return self._json
def _json_load_str_or_bytes(fp):
    """Decode the whole content of `fp` to text and parse it as JSON."""
    return json.loads(fp.read().decode())


# On python2.7 and python>=3.6 json.load() is used directly; the decoding
# fallback above stays in place for the remaining versions only.
if sys.version_info[0:2] == (2,7) or sys.version_info[0:2] >= (3,6):
    _json_load_str_or_bytes = json.load



def credstore_exists():
    """Return True when the credentials store is present as a regular file."""
    credstore_path = os.path.join(USER_HOME, BRAVO_DIR, BRAVO_CREDSTORE)
    return os.path.isfile(credstore_path)


def create_credstore():
    """Create the Bravo directory (mode 0700) and an empty credentials store (mode 0600).

    Either step is skipped when its target already exists.
    """
    bravo_dir = os.path.join(USER_HOME, BRAVO_DIR)
    if not os.path.isdir(bravo_dir):
        os.mkdir(bravo_dir, 0o700)
    store_path = os.path.join(bravo_dir, BRAVO_CREDSTORE)
    if os.path.isfile(store_path):
        return
    # Opening in append mode creates the file without writing anything to it.
    with open(store_path, 'a'):
        pass
    os.chmod(store_path, 0o600)


def read_credstore():
    """Load and validate the credentials store.

    The store must be a JSON object with a string `active` key naming an entry
    of the `all` object, and every entry of `all` must be an object holding
    `access_token`, `created`, `token_type` and `revoked`.

    Returns:
        The decoded credentials store (a dict).

    Raises:
        BravoException: when the store file is missing, is not valid JSON
            (for example the empty file left by `create_credstore`), or does
            not have the structure described above.
    """
    required = ['access_token', 'created', 'token_type', 'revoked']
    invalid_store = 'Invalid or outdated credentials store. You may need to run login.'
    fpath = os.path.join(USER_HOME, BRAVO_DIR, BRAVO_CREDSTORE)
    if not os.path.exists(fpath):
        raise BravoException("Please use `./bravo login` to login.  Currently {} does not exist.".format(fpath))
    with open(fpath, 'r') as credstore_file:
        try:
            credstore = json.load(credstore_file)
        except ValueError:
            # JSON decoding errors are ValueError (or a subclass of it).
            raise BravoException(invalid_store)
    if not isinstance(credstore, dict):
        raise BravoException(invalid_store)
    active_credentials = credstore.get('active', None)
    # A missing/None `active` fails the isinstance() test as well.
    if not isinstance(active_credentials, basestring):
        raise BravoException(invalid_store)
    all_credentials = credstore.get('all', None)
    if not isinstance(all_credentials, dict):
        raise BravoException(invalid_store)
    if active_credentials not in all_credentials:
        raise BravoException(invalid_store)
    for credentials in all_credentials.values():
        if not isinstance(credentials, dict):
            raise BravoException(invalid_store)
        if not all(k in credentials for k in required):
            raise BravoException('Invalid access token entry in credentials store. You may need to run login.')
    return credstore


def write_credstore(data):
    """Serialize `data` as indented JSON into the credentials store.

    The file is opened with mode 0600 so that, if it has to be created, the
    access tokens are never written to a file with wider permissions.  The
    final chmod also tightens a store that already existed with other
    permissions.
    """
    path = os.path.join(USER_HOME, BRAVO_DIR, BRAVO_CREDSTORE)
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as credstore:
        json.dump(data, credstore, indent=4)
    os.chmod(path, 0o600)


def login():
    """Sign in to the Bravo API and record the access token in the credentials store.

    Steps:
      1. Ask the authentication server for the caller's public IP.
      2. If the store already holds a non-revoked entry for that IP, make it
         the active one and stop.
      3. Otherwise request an authentication link, print it, and poll the
         token endpoint every 5 seconds until an access token is returned.
      4. Save the new token as the active entry.

    Raises:
        BravoException: when any of the authentication server calls fails.
    """
    bravo_response = requests.get(BRAVO_IP_API)
    if bravo_response.status_code != 200:
        raise BravoException('Error while obtaining your public IP with Bravo API authentication server.')
    ip = bravo_response.json()['ip']

    credstore = None
    if credstore_exists():
        try:
            credstore = read_credstore()
        except Exception:
            # Deliberate best effort: a store that cannot be read or validated
            # is discarded and rebuilt from scratch by this login.
            credstore = None
    else:
        create_credstore()

    if credstore is None:
        credstore = {'active': None, 'all': {}}
    else:
        # Only reading the store is best effort; a failure to write it below
        # is a real error and is not swallowed.
        credentials = credstore['all'].get(ip, None)
        if credentials is not None and not credentials['revoked']:
            credstore['active'] = ip
            write_credstore(credstore)
            print('You are signed in.')
            return

    bravo_response = requests.get(BRAVO_AUTH_API)
    if bravo_response.status_code == 400:
        raise BravoException(bravo_response.json().get('error', 'Failed to obtain authentication link.'))
    elif bravo_response.status_code != 200:
        raise BravoException('Error while obtaining authentication link from Bravo API authentication server.')
    bravo_response_data = bravo_response.json()
    auth_url = bravo_response_data['auth_url']
    auth_token = bravo_response_data['auth_token']
    print('Go to the following link in your browser:\n\n{}\n'.format(auth_url))
    print('\nContacting Bravo API for access tokens...')
    sys.stdout.flush()

    # Poll until the server hands out an access token for this auth token.
    while True:
        time.sleep(5)
        bravo_response = requests.post(BRAVO_TOKEN_API, data={'auth_token': auth_token})
        if bravo_response.status_code == 400:
            raise BravoException(bravo_response.json().get('error', 'Failed to obtain authentication token.'))
        elif bravo_response.status_code != 200:
            raise BravoException('Error while obtaining authentication token from Bravo API authentication server.')
        bravo_response_data = bravo_response.json()
        if bravo_response_data['access_token'] is not None:
            break

    # The entry is keyed by the IP reported in the token response.
    credstore['active'] = bravo_response_data['ip']
    credstore['all'][bravo_response_data['ip']] = {
        'access_token': bravo_response_data['access_token'],
        'token_type': bravo_response_data['token_type'],
        'created': int(time.time()),
        'revoked': False
    }
    write_credstore(credstore)
    print('Done.')
    print('You are signed in.')


def print_access_token():
    """Print the access token of the active credentials.

    Raises:
        BravoException: when the active token is marked as revoked.
    """
    credstore = read_credstore()
    active_entry = credstore['all'][credstore['active']]
    if active_entry['revoked']:
        raise BravoException("The current token has been revoked.  Please login with `bravo login`.")
    print(active_entry['access_token'])


def revoke():
    """Revoke access and mark every stored, not yet revoked token as revoked.

    A single request is sent to the revoke endpoint, using the first
    non-revoked token found; the store is rewritten only when that request
    succeeds.

    Raises:
        BravoException: when the revoke request does not return status 200.
    """
    if not credstore_exists():
        print('No access tokens to revoke.')
        return
    credstore = read_credstore()
    stored = credstore['all']
    pending_ips = [ip for ip, entry in stored.items() if not entry['revoked']]
    if not pending_ips:
        print('No access tokens to revoke.')
        return

    first_token = stored[pending_ips[0]]['access_token']
    bravo_response = requests.get(BRAVO_REVOKE_API, params={'access_token': first_token})
    if bravo_response.status_code == 400:
        raise BravoException(bravo_response.json().get('error', 'Failed to revoke access.'))
    if bravo_response.status_code != 200:
        raise BravoException('Bravo API authentication server is not accessible.')

    for ip in pending_ips:
        stored[ip]['revoked'] = True
    write_credstore(credstore)
    print('Access tokens have been successfully revoked.')


def _query_nonpaged(url, headers):
    """GET `url` with `headers` and return the decoded JSON body.

    Raises:
        BravoException: on any status other than 200; for status 400 the
            message is the `error` field of the response body when present.
    """
    response = requests.get(url, headers=headers)
    status = response.status_code
    if status == 200:
        return response.json()
    if status == 400:
        raise BravoException(response.json().get('error', 'Failed to query data.'))
    raise BravoException('Bravo API server failed with status code {}'.format(status))


def query_meta():
    """Fetch the API root document and print it as one line of JSON.

    Raises:
        BravoException: when no credentials store exists.
    """
    if not credstore_exists():
        raise BravoException('No access tokens found. Please login first.')
    credstore = read_credstore()
    access_token = credstore['all'][credstore['active']]['access_token']
    headers = {'Authorization': 'Bearer {}'.format(access_token)}
    query_url = 'https://bravo.sph.umich.edu/freeze5/hg38/api/{}/'.format(BRAVO_API_VERSION)
    meta = _query_nonpaged(query_url, headers)
    json.dump(meta, sys.stdout)
    print('')


def parse_filter_expressions(filter_str):
    """Translate a user filter into URL query parameters.

    `filter_str` is a '&'-separated list of `<field><op><value>` expressions,
    where `<op>` is one of ==, !=, >, <, >=, <=.  Each expression becomes
    `<field>=<op-name>:<value>` (op-name being eq, ne, gt, lt, gte or lte) and
    the results are joined with '&'.  Whitespace around field and value is
    ignored.

    Example: 'AF>=0.1 & filter==PASS' -> 'AF=gte:0.1&filter=eq:PASS'.

    Raises:
        BravoException: when an expression does not contain exactly one
            operator, or when its field or value is empty.
    """
    op_map = {'==': 'eq', '!=': 'ne', '>': 'gt', '<': 'lt', '>=': 'gte', '<=': 'lte'}
    # Two-character operators come first: alternation picks the first
    # alternative that matches, so listing '>' ahead of '>=' would split
    # 'AF>=0.5' at '>' and leave '=0.5' as the value.
    op_pattern = r'(==|!=|>=|<=|>|<)'
    parsed_filter = []
    for expression in filter_str.split('&'):
        tokens = re.split(op_pattern, expression)
        if len(tokens) != 3:
            raise BravoException('Bad filter expression: "{}".'.format(expression.strip()))
        field, value = tokens[0].strip(), tokens[2].strip()
        if not field or not value:
            raise BravoException('Bad filter expression: "{}".'.format(expression.strip()))
        parsed_filter.append('{}={}:{}'.format(field, op_map[tokens[1]], value))
    return '&'.join(parsed_filter)


def _query_paged(headers, url):
    """Yield the items of a paged query, following each page's `next` link.

    For a response whose `format` is 'vcf', the first page additionally yields
    every entry of `meta` followed by `header` before the data items.
    Iteration stops when a page's `next` value is falsy.

    Raises:
        BravoException: when a page request does not return status 200.
    """
    on_first_page = True
    next_url = url
    while next_url:
        response = requests.get(next_url, headers=headers)
        status = response.status_code
        if status == 400:
            raise BravoException(response.json().get('error', 'Failed to query data.'))
        if status != 200:
            raise BravoException("Request failed with error:\n" + str(response._exception))
        page = response.json()
        if page['format'] == 'vcf' and on_first_page:
            for meta_line in page['meta']:
                yield meta_line
            yield page['header']
        for record in page['data']:
            yield record
        next_url = page['next']
        on_first_page = False


def query_region(chromosome, start, end, format_name, filter_expr):
    """Query a chromosomal region and print the result line by line.

    With `format_name` equal to 'vcf' each returned line is printed as is;
    otherwise each item is printed as one line of JSON.  A truthy
    `filter_expr` is translated and appended to the query.

    Raises:
        BravoException: when no credentials store exists.
    """
    if not credstore_exists():
        raise BravoException('No access tokens found. Please login first.')
    credstore = read_credstore()
    access_token = credstore['all'][credstore['active']]['access_token']
    headers = {'Authorization': 'Bearer {}'.format(access_token)}
    as_vcf = format_name == 'vcf'
    query_url = 'https://bravo.sph.umich.edu/freeze5/hg38/api/{}/region?chrom={}&start={}&end={}&vcf={}'.format(
        BRAVO_API_VERSION, chromosome, start, end, 1 if as_vcf else 0)
    if filter_expr:
        query_url = '{}&{}'.format(query_url, parse_filter_expressions(filter_expr))
    for record in _query_paged(headers, query_url):
        if as_vcf:
            print(record)
            continue
        json.dump(record, sys.stdout)
        print('')


def query_gene(name, format_name, filter_expr):
    """Query by gene name or identifier and print the result line by line.

    With `format_name` equal to 'vcf' each returned line is printed as is;
    otherwise each item is printed as one line of JSON.  A truthy
    `filter_expr` is translated and appended to the query.

    Raises:
        BravoException: when no credentials store exists.
    """
    if not credstore_exists():
        raise BravoException('No access tokens found. Please login first.')
    credstore = read_credstore()
    access_token = credstore['all'][credstore['active']]['access_token']
    headers = {'Authorization': 'Bearer {}'.format(access_token)}
    as_vcf = format_name == 'vcf'
    query_url = 'https://bravo.sph.umich.edu/freeze5/hg38/api/{}/gene?name={}&vcf={}'.format(
        BRAVO_API_VERSION, name, 1 if as_vcf else 0)
    if filter_expr:
        query_url = '{}&{}'.format(query_url, parse_filter_expressions(filter_expr))
    for record in _query_paged(headers, query_url):
        if as_vcf:
            print(record)
            continue
        json.dump(record, sys.stdout)
        print('')


def query_variant(variant_id, chromosome, position, format_name):
    """Query a variant, either by identifier or by chromosome and position.

    Exactly one way of addressing the variant must be used: `variant_id`
    alone, or `chromosome` together with `position`.  With `format_name` equal
    to 'vcf' each returned line is printed as is; otherwise each item is
    printed as one line of JSON.

    Raises:
        BravoException: when no credentials store exists, or when the
            combination of arguments is not one of the two allowed forms.
    """
    if not credstore_exists():
        raise BravoException('No access tokens found. Please login first.')
    credstore = read_credstore()
    access_token = credstore['all'][credstore['active']]['access_token']
    headers = {'Authorization': 'Bearer {}'.format(access_token)}
    as_vcf = format_name == 'vcf'
    vcf_flag = 1 if as_vcf else 0
    if variant_id is not None:
        if chromosome is not None or position is not None:
            raise BravoException('"-v,--variant" is not allowed together with "-c,--chromosome" and "-p,--position".')
        query_url = 'https://bravo.sph.umich.edu/freeze5/hg38/api/{}/variant?variant_id={}&vcf={}'.format(
            BRAVO_API_VERSION, variant_id, vcf_flag)
    else:
        if chromosome is None or position is None:
            raise BravoException('Provide both "-c,--chromosome" and "-p,--position".')
        query_url = 'https://bravo.sph.umich.edu/freeze5/hg38/api/{}/variant?chrom={}&pos={}&vcf={}'.format(
            BRAVO_API_VERSION, chromosome, position, vcf_flag)
    for record in _query_paged(headers, query_url):
        if as_vcf:
            print(record)
            continue
        json.dump(record, sys.stdout)
        print('')


def load_region(chromosome, position, filter_expr, window=8000):
    """Load the variants of the window starting at `position` into memory.

    Args:
        chromosome: chromosome name sent as the `chrom` query parameter.
        position: first base pair of the window.
        filter_expr: optional filter; when truthy it is translated and
            appended to the query.
        window: number of base pairs added to `position` to get the end of
            the window (the default of 8000 is approx. 1,000 variants).

    Returns:
        A dict with keys `chromosome`, `start`, `end` and `variants`, the
        latter mapping each returned item's `variant_id` to the item.

    Raises:
        BravoException: when no credentials store exists.  This is raised,
            like in the other query helpers, rather than printed: standard
            output carries the annotated VCF and the caller indexes the
            returned region.
    """
    if not credstore_exists():
        raise BravoException('No access tokens found. Please login first.')
    credstore = read_credstore()
    headers = {'Authorization': 'Bearer {}'.format(credstore['all'][credstore['active']]['access_token'])}
    start = position
    end = start + window
    region = {'chromosome': chromosome, 'start': start, 'end': end, 'variants': {}}
    query_url = 'https://bravo.sph.umich.edu/freeze5/hg38/api/{}/region?chrom={}&start={}&end={}&vcf=0'.format(
        BRAVO_API_VERSION, chromosome, start, end)
    if filter_expr:
        query_url = '{}&{}'.format(query_url, parse_filter_expressions(filter_expr))
    for line in _query_paged(headers, query_url):
        region['variants'][line['variant_id']] = line
    return region


def load_version():
    """Return the `dataset` field of the API root document.

    Raises:
        BravoException: when no credentials store exists.
    """
    if not credstore_exists():
        raise BravoException('No access tokens found. Please login first.')
    credstore = read_credstore()
    access_token = credstore['all'][credstore['active']]['access_token']
    headers = {'Authorization': 'Bearer {}'.format(access_token)}
    query_url = 'https://bravo.sph.umich.edu/freeze5/hg38/api/{}/'.format(BRAVO_API_VERSION)
    meta = _query_nonpaged(query_url, headers)
    return meta['dataset']


def annotate(filter_expr):
    """Annotate the VCF read from standard input and write it to standard output.

    The input must start with a `##fileformat=VCF` line.  `##` header lines
    are copied, the four BRAVO_* INFO declarations are inserted right before
    the `#CHROM` line, and every record whose CHROM-POS-REF-ALT identifier is
    found in Bravo gets BRAVO_AN, BRAVO_AC, BRAVO_AF and BRAVO_FILTER appended
    to its INFO column.  Records not found are copied unchanged.  Records are
    assumed to be bi-allelic.

    Args:
        filter_expr: optional filter passed on to the region queries.

    Raises:
        BravoException: when no credentials store exists, when the input does
            not start with a VCF fileformat line, or when a record has fewer
            than 8 columns or a non-integer position.
    """
    if not credstore_exists():
        raise BravoException('No access tokens found. Please login first.')
    region = None
    data_version = load_version()
    fileformat_line = sys.stdin.readline()
    # An empty input fails this test too; report it instead of exiting silently.
    if not fileformat_line.startswith('##fileformat=VCF'):
        raise BravoException('Input is not a VCF: the first line must start with "##fileformat=VCF".')
    sys.stdout.write('{}\n'.format(fileformat_line.rstrip()))
    # Line 1 was consumed above, so the remaining lines are numbered from 2.
    for line_no, in_line in enumerate(sys.stdin, 2):
        if in_line.startswith('#'):
            if in_line.startswith('##'):
                sys.stdout.write('{}\n'.format(in_line.rstrip()))
            elif in_line.startswith('#CHROM'):
                sys.stdout.write('##INFO=<ID=BRAVO_AN,Number=1,Type=Integer,Description="Number of Alleles in Samples with Coverage from {}">\n'.format(data_version))
                sys.stdout.write('##INFO=<ID=BRAVO_AC,Number=A,Type=Integer,Description="Alternate Allele Counts in Samples with Coverage from {}">\n'.format(data_version))
                sys.stdout.write('##INFO=<ID=BRAVO_AF,Number=A,Type=Float,Description="Alternate Allele Frequencies from {}">\n'.format(data_version))
                # NOTE(review): declared as String because a filter status is
                # not a number - confirm against the API's `filter` field.
                sys.stdout.write('##INFO=<ID=BRAVO_FILTER,Number=A,Type=String,Description="Filter from {}">\n'.format(data_version))
                sys.stdout.write('{}\n'.format(in_line.rstrip()))
            continue
        record = in_line.rstrip()
        if not record:
            # Blank lines carry no record; skip them.
            continue
        # Columns after INFO (FORMAT and samples) stay together in one field.
        in_fields = record.split('\t', 8)
        if len(in_fields) < 8:
            raise BravoException('Malformed VCF record at line {}: expected at least 8 tab-separated columns.'.format(line_no))
        chromosome = in_fields[0]
        try:
            position = int(in_fields[1])
        except ValueError:
            raise BravoException('Malformed VCF record at line {}: position "{}" is not an integer.'.format(line_no, in_fields[1]))
        ref = in_fields[3]
        alt = in_fields[4]  # assume bi-allelic
        # Fetch a new window whenever the record falls outside the cached one.
        if region is None or region['chromosome'] != chromosome or position < region['start'] or position > region['end']:
            region = load_region(chromosome, position, filter_expr)
        variant_id = '{}-{}-{}-{}'.format(chromosome, position, ref, alt)
        bravo_variant = region['variants'].get(variant_id, None)
        if bravo_variant is None:
            sys.stdout.write('{}\n'.format(record))
            continue
        new_info = 'BRAVO_AN={};BRAVO_AC={};BRAVO_AF={};BRAVO_FILTER={}'.format(
            bravo_variant['allele_num'], bravo_variant['allele_count'], bravo_variant['allele_freq'], bravo_variant['filter'])
        # '.' means the record had no INFO of its own.
        info = new_info if in_fields[7] == '.' else '{};{}'.format(in_fields[7], new_info)
        out_fields = [chromosome, '{}'.format(position)] + in_fields[2:7] + [info] + in_fields[8:]
        sys.stdout.write('{}\n'.format('\t'.join(out_fields)))


if __name__ == '__main__':
    # Without any argument there is nothing to do: show the help and fail.
    if len(sys.argv) == 1:
        argparser.print_help(sys.stderr)
        sys.exit(1)
    args = argparser.parse_args()
    # Map each sub-command name to a zero-argument callable running it.
    handlers = {
        'login': login,
        'print-access-token': print_access_token,
        'revoke': revoke,
        'query-meta': query_meta,
        'query-region': lambda: query_region(args.chromosome, args.start, args.end, args.format, args.filter),
        'query-gene': lambda: query_gene(args.gene, args.format, args.filter),
        'query-variant': lambda: query_variant(args.variant_id, args.chromosome, args.position, args.format),
        'annotate': lambda: annotate(args.filter),
    }
    handler = handlers.get(args.command)
    try:
        if handler is not None:
            handler()
    except BravoException as e:
        # Flush regular output first, then report the error on stderr.
        sys.stdout.flush()
        print(e, file=sys.stderr)
        sys.exit(1)
