#!/usr/bin/env python
#
# Copyright (C) 2012 Yubico AB. All rights reserved.
#
"""
This is a tool to decrypt AEADs generated using a YubiHSM, provided that
you know the AES key used and, for AEAD files that do not record it, the
key handle used.

This can be used together with yhsm-generate-keys to generate a number
of AEADs, and then decrypt them to program YubiKeys accordingly.
"""

import os
import sys
import fcntl
import argparse
import traceback

import pyhsm

# Module-level state:
#   args  -- the parsed command-line namespace; assigned by main().
#   yknum -- count of rows printed in yubikey-csv format, used as the
#            first CSV column; incremented by process_file().
args, yknum = None, 0

def parse_args(argv=None):
    """
    Parse the command line arguments.

    `argv` is the list of arguments to parse (without the program name);
    when None, argparse falls back to sys.argv[1:].

    Returns the argparse namespace. argparse exits (SystemExit) when
    --aes-key or the paths are missing, or when --format is not one of
    the supported output formats.
    """
    parser = argparse.ArgumentParser(description = 'Decrypt AEADs',
                                     add_help = True,
                                     formatter_class = argparse.ArgumentDefaultsHelpFormatter,
                                     )
    parser.add_argument('-v', '--verbose',
                        dest='verbose',
                        action='store_true', default=False,
                        help='Enable verbose operation',
                        )
    parser.add_argument('--debug',
                        dest='debug',
                        action='store_true', default=False,
                        help='Enable debug operation',
                        )
    parser.add_argument('--format',
                        dest='format',
                        default='raw',
                        # Reject unsupported formats at parse time instead of
                        # discovering the problem once per processed file.
                        choices=['raw', 'yubikey-csv'],
                        help='Select output format (raw or yubikey-csv)',
                        )
    parser.add_argument('--print-filename',
                        dest='print_filename',
                        action='store_true', default=False,
                        help='Prefix each row with the AEAD filename',
                        )
    parser.add_argument('--key-handle',
                        dest='key_handle',
                        help='Key handle used when generating the AEADs.',
                        metavar='HANDLE',
                        )
    parser.add_argument('--aes-key',
                        dest='aes_key',
                        required=True,
                        help='AES key used when generating the AEADs.',
                        metavar='HEXSTR',
                        )
    parser.add_argument('paths',
                        nargs='+',
                        help='Files and/or directories to process.',
                        metavar='FILE-OR-DIR'
                        )
    return parser.parse_args(argv)

def process_file(path, fn, args):
    """
    Read one AEAD file, decrypt it with the --aes-key key and print one row.

    `path` is the directory holding the file and `fn` its base name. AEAD
    files that carry no nonce (file version 0) take their nonce from the
    modhex-encoded `fn` and their key handle from --key-handle.

    The printed row depends on args.format:
      raw          -- the decrypted plaintext as a hex string
      yubikey-csv  -- a numbered CSV row with public id, uid, key, an
                      all-zero access code and an empty timestamp
    With --print-filename the row is prefixed by the file name and a space.

    Returns True on success. Returns False after writing an error to stderr
    if the file has no key handle and none was given, or if the output
    format is unknown. Exceptions from loading/decrypting propagate.
    """
    global yknum

    full_fn = os.path.join(path, fn)

    if args.debug:
        print("Loading AEAD : %s" % full_fn)

    aead = pyhsm.aead_cmd.YHSM_GeneratedAEAD(None, None, '')
    aead.load(full_fn)

    if not aead.nonce:
        # AEAD file version 0, need to fill in nonce etc.
        if args.key_handle is None:
            sys.stderr.write("ERROR: AEAD in file %s does not include key_handle, and none provided.\n" % (full_fn))
            return False
        aead.key_handle = pyhsm.util.key_handle_to_int(args.key_handle)
        # modhex_decode() turns the modhex file name into ordinary hex text,
        # but the nonce must be raw bytes -- the form the yubikey-csv branch
        # below hex-encodes again when printing the public id.
        aead.nonce = pyhsm.yubikey.modhex_decode(fn).decode('hex')

    aes_key = args.aes_key.decode('hex')
    if args.debug:
        print(aead)
        print("AEAD len %i : %s" % (len(aead.data), aead.data.encode('hex')))
    pt = pyhsm.soft_hsm.aesCCM(aes_key, aead.key_handle, aead.nonce, aead.data, decrypt = True)

    if args.format == 'raw':
        line = pt.encode('hex')
    elif args.format == 'yubikey-csv':
        key = pt[:pyhsm.defines.KEY_SIZE]
        uid = pt[pyhsm.defines.KEY_SIZE:]
        access_code = '00' * 6
        timestamp = ''
        yknum += 1
        line = "%i,%s,%s,%s,%s,%s,,,,," % (yknum,
                                           pyhsm.yubikey.modhex_encode(aead.nonce.encode('hex')),
                                           uid.encode('hex'),
                                           key.encode('hex'),
                                           access_code,
                                           timestamp,
                                           )
    else:
        # Previously an unknown format silently produced no row at all.
        sys.stderr.write("ERROR: Unknown output format %s.\n" % (args.format))
        return False

    if args.print_filename:
        # Build the whole row and print it once, so the filename is
        # separated from the data by exactly one space.
        line = "%s %s" % (full_fn, line)
    print(line)

    return True

def walk_dir(path, args):
    """
    Recursively run process_file() on every file below directory `path`.

    Sub-directories and files are visited in sorted order, so the output --
    including the row numbers of yubikey-csv output -- is reproducible
    between runs.

    Returns True when every file was processed, False as soon as one file
    fails. If processing raises, its traceback is written to stderr (never
    stdout, which carries the decrypted data) and False is returned.
    """
    if args.debug:
        print("Walking %s" % path)

    for root, dirs, files in os.walk(path):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for fn in sorted(files):
            try:
                if not process_file(root, fn, args):
                    return False
            except Exception:
                sys.stderr.write(traceback.format_exc())
                return False
    return True

def main():
    """
    Main function when running as a program.

    Parses the command line into the module-level `args`, then processes
    each given path: directories recursively via walk_dir(), anything else
    as a single AEAD file via process_file().

    Returns True if every path was processed successfully, False as soon as
    one fails. Tracebacks of unexpected errors are written to stderr.
    """
    global args
    args = parse_args()

    for path in args.paths:
        if os.path.isdir(path):
            if not walk_dir(path, args):
                return False
        else:
            # Pass directory and base name separately: for version 0 AEAD
            # files process_file() derives the nonce from the base name, so
            # directory components must not be part of it.
            try:
                if not process_file(os.path.dirname(path), os.path.basename(path), args):
                    return False
            except Exception:
                sys.stderr.write(traceback.format_exc())
                return False
    # Explicit success value; falling off the end returned None, which the
    # caller treated as failure (exit status 1).
    return True

if __name__ == '__main__':
    # Translate main()'s success flag into the process exit status
    # (0 on success, 1 on any failure).
    exit_status = 0 if main() else 1
    sys.exit(exit_status)
