#!/usr/bin/python3

__unittest = True

import subprocess
import re
import unittest
import xmlrunner
from subprocess import CalledProcessError, Popen, PIPE

#-------helpers---------------

def starts_uint( s ):
    matches = re.findall('^\d+', s)
    if matches:
        return (int(matches[0]), len(matches[0]))
    else:
        return (0, 0)

def starts_int( s ):
    matches = re.findall('^-?\d+', s)
    if matches:
        return (int(matches[0]), len(matches[0]))
    else:
        return (0, 0)

def unsigned_reinterpret(x):
    if x < 0:
        return x + 2**64
    else:
        return x

def first_or_empty( s ):
    sp = s.split()
    if sp == [] :
        return ''
    else:
        return sp[0]

#-----------------------------

before_all="""
%macro call 1
mov rax, -1
push rbx
push rbp
push r12 
push r13 
push r14 
push r15 
call %1
cmp r15, [rsp] 
jne convention_error
pop r15
cmp r14, [rsp] 
jne convention_error
pop r14
cmp r13, [rsp] 
jne convention_error
pop r13
cmp r12, [rsp] 
jne convention_error
pop r12
cmp rbp, [rsp] 
jne convention_error
pop rbp
cmp rbx, [rsp] 
jne convention_error
pop rbx
mov rdi, -1
mov rsi, -1
mov rcx, -1
mov r8, -1
mov r9, -1
mov r10, -1
mov r11, -1
%endmacro

%include "lib.inc"

global _start

section .text
convention_error:
    mov rax, 1
    mov rdi, 2
    mov rsi, err_calling_convention
    mov rdx, err_calling_convention.end - err_calling_convention
    syscall
    mov rax, 60
    mov rdi, -41
    syscall

section .data
    err_calling_convention: db "You did not respect the calling convention! Check that you handled caller-saved and callee-saved registers correctly", 10
    .end:
"""

class IOLibraryTest(unittest.TestCase):
    def compile(self, fname, text):
        f = open( fname + '.asm', 'w')
        f.write( text )
        f.close()

        self.assertEqual(subprocess.call( ['nasm', '-f', 'elf64', fname + '.asm', '-o', fname+'.o'] ), 0, 'failed to compile')
        self.assertEqual(subprocess.call( ['ld', '-o' , fname, fname+'.o'] ), 0, 'failed to link')

    def launch(self, fname, input):
        output = b''
        try:
            p = Popen(['./'+fname], shell=None, stdin=PIPE, stdout=PIPE)
            (output, _) = p.communicate(input.encode())
            self.assertNotEqual(p.returncode, -11, 'segmentation fault')
            return (output.decode(), p.returncode)
        except CalledProcessError as exc:
            self.assertNotEqual(exc.returncode, -11, 'segmentation fault')
            return (exc.output.decode(), exc.returncode)

    def perform(self, fname, text, input):
        self.compile(fname, before_all + text)
        return self.launch(fname, input)



    def test_string_length(self):
        inputs = ['asdkbasdka', 'qwe qweqe qe', '']
        for input in inputs:
            text = """
section .data
str: db '""" + input + """', 0

section .text
_start:
    mov rdi, str
    call string_length
    mov rdi, rax
    mov rax, 60
    syscall
"""
            (output, code) = self.perform('string_length', text, input)
            self.assertEqual(code, len(input), 'string_length(%s) returned wrong length: %d' % (repr(input), code))



    def test_print_string(self):
        inputs = ['ashdb asdhabs dahb', ' ', '']
        for input in inputs:
            text = """
section .data
str: db '""" + input + """', 0

section .text
_start:
    mov rdi, str
    call print_string
    xor rdi, rdi
    mov rax, 60
    syscall
"""
            (output, code) = self.perform('print_string', text, input)
            self.assertEqual(output, input, 'print_string(%s) printed wrong string: %s' % (repr(input), repr(output)))



    def test_string_copy(self):
        inputs = ['ashdb asdhabs dahb', ' ', '']
        for input in inputs:
            text = """
section .data
arg1: db '""" + input + """', 0
arg2: times """ + str(len(input) + 1) + """ db  66

section .text
_start:
    mov rdi, arg1
    mov rsi, arg2
    mov rdx, """ + str(len(input) + 1) + """
    call string_copy

    mov rdi, arg2 
    call print_string
    mov rax, 60
    xor rdi, rdi
    syscall
"""
            (output, code) = self.perform('string_copy', text, input)
            self.assertEqual(output, input, 'string_copy(%s) put wrong string into buffer: %s' % (repr(input), repr(output)))



    def test_string_copy_too_long(self):
        inputs = ['ashdb asdhabs dahb', ' ', '']
        for input in inputs:
            text = """
section .rodata
err_too_long_msg: db "string is too long", 10, 0

section .data
arg1: db '""" + input + """', 0
arg2: times """ + str(len(input)//2) + """ db  66

section .text
_start:
    mov rdi, arg1
    mov rsi, arg2
    mov rdx, """ + str(len(input)//2) + """
    call string_copy
    test rax, rax
    jnz .good
    mov rdi, err_too_long_msg 
    call print_string
    jmp _exit
    .good:
    mov rdi, arg2 
    call print_string
_exit:
    mov rax, 60
    xor rdi, rdi
    syscall
"""
            (output, code) = self.perform('string_copy_too_long', text, input)
            self.assertNotEqual(output.find('too long'), -1, 'string_copy(%s) should have failed, but returned: %s' % (repr(input), repr(output)))



    def test_print_char(self):
        inputs = ['a', ' ', 'c']
        for input in inputs:
            text = """
section .text
_start:
    mov rdi, '""" + input + """'
    call print_char
    mov rax, 60
    xor rdi, rdi
    syscall
"""
            (output, code) = self.perform('print_char', text, input)
            self.assertEqual(output, input, 'print_char(%s) printed wrong char: %s' % (repr(input), repr(output)))



    def test_print_uint(self):
        inputs = ['-1', '12345234121', '0', '12312312', '123123']
        for input in inputs:
            text = """
section .text
_start:
    mov rdi, """ + input + """
    call print_uint
    mov rax, 60
    xor rdi, rdi
    syscall
"""
            (output, code) = self.perform('print_uint', text, input)
            uinput = str(unsigned_reinterpret(int(input)))
            self.assertEqual(output, uinput, 'print_uint(%s) printed wrong number: %s, expected: %s' % (repr(input), repr(output), repr(uinput)))



    def test_print_int(self):
        inputs = ['-1', '-12345234121', '0', '123412312', '123123']
        for input in inputs:
            text = """
section .text
_start:
    mov rdi, """ + input + """
    call print_int
    mov rax, 60
    xor rdi, rdi
    syscall
"""
            (output, code) = self.perform('print_int', text, input)
            self.assertEqual(output, input, 'print_int(%s) printed wrong number: %s' % (repr(input), repr(output)))



    def test_read_char(self):
        inputs = ['-1', '-1234asdasd5234121', '', '   ', '\t   ', 'hey ya ye ya', 'hello world' ]
        for input in inputs:
            text = """
section .text
_start:
    call read_char
    mov rdi, rax
    mov rax, 60
    syscall
"""
            (output, code) = self.perform('read_char', text, input)
            if input == "":
                self.assertEqual(code, 0, 'read_char with empty input should return 0')
            else:
                self.assertEqual(code, ord(input[0]), 'read_char(%d) returned incorrect char: %d' % (ord(input[0]), code))



    def test_read_word(self):
        inputs = ['-1'] # , '-1234asdasd5234121', '', '   ', '\t   ', 'hey ya ye ya', 'hello world' ],

        for input in inputs:
            text = """
section .data
word_buf: times 20 db 0xca

section .text
_start:
    mov rdi, word_buf
    mov rsi, 20 
    call read_word
    mov rdi, rax
    call print_string

    mov rax, 60
    xor rdi, rdi
    syscall
"""
            (output, code) = self.perform('read_word', text, input)
            input_word = first_or_empty(input)
            self.assertEqual(output, input_word, 'read_word(%s) put incorrect word in the buffer: %s, expected: %s' % (repr(input), repr(output), repr(input_word)))



    def test_read_word_length(self):
        inputs = ['-1', '-1234asdasd5234121', '', '   ', '\t   ', '\t   123', 'hey ya ye ya', 'hello world' ]
        for input in inputs:
            text = """
section .data
word_buf: times 20 db 0xca

section .text
_start:
    mov rdi, word_buf
    mov rsi, 20 
    call read_word
    mov rax, 60
    mov rdi, rdx
    syscall
"""
            (output, code) = self.perform('read_word_length', text, input)
            input_word = first_or_empty(input)
            self.assertEqual(code, len(input_word), 'read_word(%s) returned incorrect length: %d, expected: %d' % (repr(input), code, len(input_word)))



    def test_read_word_too_long(self):
        inputs = [ 'asdbaskdbaksvbaskvhbashvbasdasdads wewe', 'short' ]
        for input in inputs:
            text = """
section .data
stub: times 5 db 0xca
word_buf: times 20 db 0xca

section .text
_start:
    mov rdi, word_buf
    mov rsi, 20 
    call read_word
    mov rdi, rax
    mov rax, 60
    syscall
"""
            (output, code) = self.perform('read_word_too_long', text, input)
            input_word = first_or_empty(input)
            if len(input_word) > 19:
                self.assertEqual(code, 0, 'read_word(%s) overflows buffer, but does not fail' % repr(input))
            else:
                self.assertNotEqual(code, 0, 'read_word(%s) does not overflow buffer, but fails' % repr(input))



    def test_parse_uint(self):
        inputs = ["0", "1234567890987654321hehehey", "1" ]
        for input in inputs:
            text = """
section .data
input: db '""" + input  + """', 0

section .text
_start:
    mov rdi, input
    call parse_uint
    push rdx
    mov rdi, rax
    call print_uint
    mov rax, 60
    pop rdi
    syscall
"""
            (output, code) = self.perform('parse_uint', text, input)
            (input_num, input_len) = starts_uint(input)

            self.assertEqual(output, str(input_num), 'parse_uint(%s) parsed wrong number: %s, expected: %s' % (repr(input), repr(output), repr(str(input_num))))
            self.assertEqual(code, input_len, 'parse_uint(%s) returned wrong length: %d, expected: %d' % (repr(input), code, input_len))



    def test_parse_int(self):
        inputs = ["0", "1234567890987654321hehehey", "-1dasda", "-eedea", "-123123123", "1" ]
        for input in inputs:
            text = """
section .data
input: db '""" + input  + """', 0

section .text
_start:
    mov rdi, input
    call parse_int
    push rdx
    mov rdi, rax
    call print_int
    mov rax, 60
    pop rdi
    syscall
"""
            (output, code) = self.perform('parse_int', text, input)
            (input_num, input_len) = starts_int(input)

            if input_len == 0:
                self.assertEqual(output, '0', 'parse_int(%s) should have failed, but parsed %s' % (repr(input), output))
            else:
                self.assertEqual(output, str(input_num), 'parse_int(%s) parsed wrong number: %s, expected: %s' % (repr(input), repr(output), repr(str(input_num))))
                self.assertEqual(code, input_len, 'parse_int(%s) returned wrong length: %d, expected: %d' % (repr(input), code, input_len))



    def test_string_equals(self):
        inputs = ['ashdb asdhabs dahb', ' ', '', "asd" ]
        for input in inputs:
            text = """
section .data
str1: db '""" + input + """',0
str2: db '""" + input + """',0

section .text
_start:
    mov rdi, str1
    mov rsi, str2
    call string_equals
    mov rdi, rax
    mov rax, 60
    syscall
"""
            (output, code) = self.perform('string_equals', text, input)
            self.assertEqual(code, 1, 'string_equals(%s, %s) should return 1' % (repr(input), repr(input)))



    def test_string_not_equals(self):
        inputs = ['ashdb asdhabs dahb', ' ', '', "asd" ]
        for input in inputs:
            text = """
section .data
str1: db '""" + input + """',0
str2: db '""" + input + """!!',0

section .text
_start:
    mov rdi, str1
    mov rsi, str2
    call string_equals
    mov rdi, rax
    mov rax, 60
    syscall
"""
            (output, code) = self.perform('string_not_equals', text, input)
            self.assertEqual(code, 0, 'string_equals(%s, %s!!) should return 0' % (repr(input), repr(input)))



if __name__ == "__main__":
    with open('report.xml', 'w') as report:
        unittest.main(testRunner=xmlrunner.XMLTestRunner(output=report), failfast=False, buffer=False, catchbreak=False)