#!/usr/bin/env python
""" Show me 'weird' DNS packets """
from __future__ import print_function
import os
import argparse
from pprint import pprint

# Local imports
from chains.utils import signal_utils, net_utils, compat
from chains.sources import packet_streamer
from chains.links import packet_meta, reverse_dns, transport_meta, dns_meta
from chains.sinks import packet_summary, packet_printer

def run(iface_name=None, bpf=None, verbose=None, max_packets=100):
    """Run the Weird DNS Example: report packets whose DNS data is flagged 'weird'.

    Builds the chain PacketStreamer -> PacketMeta -> ReverseDNS -> TransportMeta
    -> DNSMeta, pulls items from the end of that chain and prints a report for
    every item that has a 'weird' entry under its 'dns' key.

    Args:
        iface_name: Handed straight to PacketStreamer. The callers in this file
            pass a pcap file path (or None) here.
        bpf: BPF filter, handed straight to PacketStreamer.
        verbose: When truthy, the complete DNS record of each weird packet is
            pretty-printed in addition to the summary sections.
        max_packets: Packet limit, handed straight to PacketStreamer.
    """

    # Create the classes
    streamer = packet_streamer.PacketStreamer(iface_name=iface_name, bpf=bpf, max_packets=max_packets)
    meta = packet_meta.PacketMeta()
    rdns = reverse_dns.ReverseDNS()
    tmeta = transport_meta.TransportMeta()
    dmeta = dns_meta.DNSMeta()

    # Set up the chain
    meta.link(streamer)
    rdns.link(meta)
    tmeta.link(rdns)
    dmeta.link(tmeta)

    # Pull the chain (here since we have a 'condition' we pull and 'push' manually)
    for item in dmeta.output_stream:

        # Only report items that carry DNS data flagged as weird
        if 'dns' not in item or 'weird' not in item['dns']:
            continue
        dns = item['dns']

        # Print out packet info
        packet = item['packet']
        src = net_utils.inet_to_str(packet['src'])
        dst = net_utils.inet_to_str(packet['dst'])
        # Both domain names are needed for the long form; fall back to the
        # address-only form rather than raising KeyError when one is missing
        if 'src_domain' in packet and 'dst_domain' in packet:
            print('%s(%s) --> %s(%s)' % (src, packet['src_domain'], dst, packet['dst_domain']))
        else:
            print('%s --> %s' % (src, dst))

        # Print DNS Info
        print('Weird DNS: ')
        for key, value in compat.iteritems(dns['weird']):
            # Plain '{}' fields: the 's' format code rejects non-string keys
            print('\t{} = {}'.format(key, str(value)))
        print('Queries:')
        pprint(dns['queries'])
        print('Answers:')
        pprint(dns['answers'])
        print('Flags:')
        pprint(dns['flags'])

        # Verbose mode: dump everything that is stored under the 'dns' key
        if verbose:
            print('Full DNS record:')
            pprint(dns)
        print('\n')

def test():
    """Test the Weird DNS Example by running it on the dns.pcap data file"""
    from chains.utils import file_utils

    # A pcap path is passed where run() takes its interface name, so the test
    # works from a file. relative_dir() is given this file plus '../data/dns.pcap';
    # presumably it resolves that path against this file's location (confirm
    # in file_utils).
    # NOTE(review): with no file given the streamer is said to read from the
    # first active interface -- not verifiable from this file.
    run(iface_name=file_utils.relative_dir(__file__, '../data/dns.pcap'))

def my_exit():
    """Print a goodbye message (the callback given to signal_utils.signal_catcher)"""
    farewell = 'Goodbye...'
    print(farewell)

if __name__ == '__main__':

    # Describe the command line options
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-bpf', type=str, help='BPF Filter for PacketStream Class')
    arg_parser.add_argument('-v', '--verbose', action="store_true", help='List full DNS packet when weird')
    arg_parser.add_argument('-m', '--max-packets', type=int, default=100, help='How many packets to process (0 for infinity)')
    arg_parser.add_argument('-p', '--pcap', type=str, help='Specify a pcap file instead of reading from live network interface')

    # Unknown arguments are reported but do not stop the script
    options, leftovers = arg_parser.parse_known_args()
    if leftovers:
        print('Unrecognized args: %s' % leftovers)

    # Expand a leading tilde in the pcap path; a missing/empty value is passed on untouched
    pcap_path = os.path.expanduser(options.pcap) if options.pcap else options.pcap

    # Run the example with my_exit registered through signal_catcher
    with signal_utils.signal_catcher(my_exit):
        run(iface_name=pcap_path, bpf=options.bpf, verbose=options.verbose, max_packets=options.max_packets)
