import sys, traceback

# (C) 2008 - Cameron Hotchkies. (cameron AT 0x90 DOT org)
# IDAPython functions for cleaning up an OSX/Objective C binary
# Current version as of REcon 2008 (June 13 2008)

# Lookup tables: IDA operand-type codes -> names, and x86 register numbers -> names.
optype = dict(enumerate(["void", "reg", "mem", "phrase", "displ", "imm", "far", "near"]))

# -1 is IDA's "no register"; 0..15 follow IDA's x86/x64 register numbering.
RegNo = {-1: "R_none"}
RegNo.update(enumerate([
    "R_ax", "R_cx", "R_dx", "R_bx", "R_sp", "R_bp", "R_si", "R_di",
    "R_r8", "R_r9", "R_r10", "R_r11", "R_r12", "R_r13", "R_r14", "R_r15",
]))


#####################################################################################
# Finish functions by identifying prologues:
#####################################################################################
def rebuild_functions_from_prologues():
    """Create functions at every "push ebp; mov ebp, esp" prologue in __text.

    Walks the __text segment, jumping to each address that is not yet inside a
    function.  If that address starts with the bytes 55 89 E5 a function is
    created there; otherwise the next 55 89 E5 pattern is searched for and a
    function is created at it unless one is already named there.  Stops at the
    end of the segment or when no further candidate exists.
    """
    seg_start = SegByName("__text")
    if seg_start == BADADDR:
        return
    seg_end = SegEnd(seg_start)

    cursor = seg_start
    while cursor < seg_end:
        cursor = find_not_func(cursor, 0x1)  # 0x1 == SEARCH_DOWN
        if cursor == BADADDR or cursor >= seg_end:
            break

        # push EBP; mov EBP,ESP
        if Byte(cursor) == 0x55 and Byte(cursor + 1) == 0x89 and Byte(cursor + 2) == 0xE5:
            start = cursor
        else:
            start = FindBinary(cursor, 0x1, "55 89 E5", 16)
            # No further prologue inside this segment; never create
            # functions at BADADDR or outside __text.
            if start == BADADDR or start >= seg_end:
                break
            cursor = start
            if GetFunctionName(start) != "":
                continue

        if not MakeFunction(start, BADADDR):
            # Step past a prologue IDA refuses to turn into a function;
            # otherwise find_not_func() can land on it again and loop forever.
            cursor = start + 1

def find_anchor(address):
    """Return the PIC anchor address of the function containing *address*.

    Scans the function's instructions in order and stops at the first of:
      * an inline ``call $+5; pop ebx`` (E8 00000000 5B) -- the anchor is the
        address of the ``pop``;
      * a ``call`` to a thunk whose body is ``mov ebx, [esp+0]; retn``
        (8B 1C 24 C3) -- the anchor is the address following the call.
    Returns BADADDR when neither pattern is found.
    """
    first = GetFunctionAttr(address, FUNCATTR_START)
    last = GetFunctionAttr(address, FUNCATTR_END)
    thunk_bytes = (0x8B, 0x1C, 0x24, 0xC3)

    ea = first
    while ea < last:
        # call $+5; pop ebx
        if Byte(ea) == 0xE8 and Dword(ea + 1) == 0x00000000 and Byte(ea + 5) == 0x5B:
            return ea + 5

        if GetMnem(ea) == "call":
            callee = Rfirst0(ea)
            # mov ebx, [esp+0]; retn
            if all(Byte(callee + idx) == want for idx, want in enumerate(thunk_bytes)):
                return NextNotTail(ea)

        ea = NextNotTail(ea)

    return BADADDR

# print "0x%08x: Anchor" % find_anchor(ScreenEA())

#####################################################################################
# Rebuild a jump table that was mangled
#
# The fn_base is the value (usually ebx) that holds the function anchor address
#  ie. call +5; pop ebx
#  or call -> mov ebx, [esp+0]; retn
#
# The jmp_table_offset is the extra hex value in the lookup
# ie. mov eax, [ebx +eax*4 + jmp_table_offset]; add eax, ebx; jmp eax
# If this uses ScreenEA(), make sure the cursor is on the jmp line
#
# Afterwards, undefine the block and make sure all of the code targets are in fact
# defined as code, or graphing will break.
#####################################################################################

def rebuild_jump_table(fn_base, jmp_table_offset, address=None):
    """Add jump xrefs from a PIC ``jmp eax`` to every target in its jump table.

    The table is at ``anchor + jmp_table_offset``.  Each dword entry is an
    offset from the anchor, and the target is ``(entry + anchor)`` truncated
    to 32 bits, mirroring ``mov eax, [ebx+eax*4+off]; add eax, ebx; jmp eax``.
    Entries are consumed while ``NextFunction(target)`` equals
    ``NextFunction(address)``.  The first target that fails this test is
    printed.

    fn_base          -- anchor address to use when find_anchor() cannot locate
                        one automatically.  An automatically found anchor takes
                        precedence.
    jmp_table_offset -- displacement from the anchor to the table.
    address          -- the ``jmp`` instruction; defaults to ScreenEA().
    """
    if address is None:
        address = ScreenEA()

    anchor = find_anchor(address)
    if anchor == BADADDR:
        # Fall back to the caller-supplied anchor instead of building the
        # table address from BADADDR.
        anchor = fn_base
    if anchor is None or anchor == BADADDR:
        print("0x%08x: no function anchor found, jump table not rebuilt" % address)
        return

    jmp_table = jmp_table_offset + anchor

    def table_target(index):
        # The binary's "add eax, ebx" wraps at 32 bits; do the same so
        # entries that point below the anchor stay valid addresses.
        return (Dword(jmp_table + 4 * index) + anchor) & 0xFFFFFFFF

    home = NextFunction(address)
    counter = 0
    entry = table_target(counter)
    while NextFunction(entry) == home:
        counter += 1
        AddCodeXref(address, entry, fl_JN)
        entry = table_target(counter)

    print("0x%08x: end jump table" % entry)


#####################################################################################
# This was wrong, but I'll keep it for reference:
#####################################################################################
#
#if (GetOpType(ScreenEA(), 1) == 4): # Base+Index+Disp
#    disp = GetOpnd(ScreenEA(),1)[5:-2]
#    dref = int(disp,16) + int(GetFunctionAttr(ScreenEA(), FUNCATTR_START))
#    add_dref(ScreenEA(), dref, dr_R)
#    MakeComm(ScreenEA(), "0x%08x: DREF!" % dref)
#    print "0x%08x: DREF!" % dref
#
#####################################################################################
# Copy the double referenced string into a local comment:
#####################################################################################
def deref_selector_strings(address=None):
    """Follow two data references from *address* and return the string there.

    ``Dfirst(address)`` gives the first hop and ``Dfirst`` of that gives the
    string address; the ASCII string found there is returned.  When *address*
    is the current screen address, or VERBOSE is set, the string is also
    written as a (non-repeatable) comment at *address*.

    Returns None when either data reference is missing.  No comment is written
    when the string is empty or unreadable, so existing comments are not
    clobbered.
    """
    if address is None:
        address = ScreenEA()

    drefea = Dfirst(address)
    if drefea == BADADDR:
        return None
    drefea = Dfirst(drefea)
    if drefea == BADADDR:
        return None
    selector = get_ascii_contents(drefea, get_max_ascii_length(drefea, 0), 0)

    # The screen-address case is kept for interactive testing.
    if selector and (address == ScreenEA() or VERBOSE):
        set_cmt(address, selector, False)

    return selector
#####################################################################################
# Rename variables passed to _objc_msgSend():
#####################################################################################
def rename_msgSend_args(function_addr=None):
    """Name the lowest two stack-frame members as _objc_msgSend arguments.

    The first frame member is renamed ``msgSend_recipient`` and the member 4
    bytes above it ``msgSend_selector``.  This presumably matches the outgoing
    [esp+0] / [esp+4] argument slots of the call; confirm for unusual frame
    layouts.

    function_addr -- any address inside the function; defaults to ScreenEA().

    Returns True if the names were applied, or False when the address has no
    frame or the frame has no members.
    """
    if function_addr is None:
        function_addr = ScreenEA()

    frm = GetFrame(function_addr)
    if frm is None or frm == BADADDR:
        return False

    firstM = GetFirstMember(frm)
    # IDA reports "no members" as -1 / BADADDR.
    if firstM == -1 or firstM == BADADDR:
        return False

    SetMemberName(frm, firstM, "msgSend_recipient")
    SetMemberName(frm, firstM + 4, "msgSend_selector")
    return True

#####################################################################################
# Iterate definitions in __inst_meth/__cls_meth section:
#####################################################################################
def iterate_fn_defs(section_name, category=False):
    """Name method implementations listed in a method-list section.

    Each list is read as follows:
      * one dword that is always zero;
      * one dword method count;
      * *count* 12-byte entries, where the dword at +0 points to the selector
        string and the dword at +8 is the implementation address.

    Each implementation is named after its selector.  If that name is taken,
    "<selector>__1" .. "<selector>__100" are tried in turn.  After a list, the
    cursor skips forward to the next named address.

    section_name -- segment name, e.g. "__inst_meth".
    category     -- if True, a position is only parsed when its first data
                    xref-to comes from the "__category" segment; other bytes
                    are skipped.
    """
    seg = SegByName(section_name)
    segend = SegEnd(seg)
    cursor = seg

    while cursor < segend:
        if category and SegName(DfirstB(cursor)) != "__category":
            cursor += 1
        else:
            # Dword at +0 is always zero; +4 holds the entry count.
            count = Dword(cursor + 4)
            cursor += 8
            # Never trust the count beyond what fits in the section: a garbage
            # value would otherwise build a huge range and run past segend.
            count = min(count, max(0, (segend - cursor) // 12))

            for _ in range(count):
                # cursor+4 => is the Type Encoded retval/argtypes see pg 127 of
                # "The Objective-C 2.0 Programming Language". This will be used later.
                sel_ptr = Dword(cursor)
                selector = get_ascii_contents(sel_ptr, get_max_ascii_length(sel_ptr, 0), 0)
                func = Dword(cursor + 8)

                if selector:
                    # 0x100 == SN_NOWARN: fail quietly so a suffixed name can be tried.
                    retval = MakeNameEx(func, selector, 0x100)
                    nc = 0
                    while not retval and nc < 100:
                        nc += 1
                        retval = MakeNameEx(func, "%s__%d" % (selector, nc), 0x100)

                # Advance to the next 12-byte entry.
                cursor += 12

        # This may not be the best way to go about this, but it's better than nothing
        # for now. Later it should be pulled from the class defs
        while Name(cursor) == "" and cursor < segend:
            cursor += 1
            
def rename_method_sections():
    """Run iterate_fn_defs over the instance, class and category method sections."""
    # (section name, category flag) -- categories are weird and need filtering.
    sections = (
        ("__inst_meth", False),
        ("__cls_meth", False),
        ("__cat_inst_meth", True),
        ("__cat_cls_meth", True),
    )
    for section_name, is_category in sections:
        iterate_fn_defs(section_name, is_category)
    
#####################################################################################
# Rename class structures:
#####################################################################################
def rename_class_structures():
    """Name each class record in __class and the structures it points to.

    For the record at each cursor:
      * the class name is read from the string pointed to by the dword at +8;
      * the record itself is named after the class;
      * the pointers at +0x1C, +0x0 and +0x18 get the suffixes "__mthd",
        "__meta" and "__ivars".

    Records without a readable name are skipped, as are NULL pointers: naming
    address 0 would clobber an unrelated location, and an empty name would
    delete an existing one.
    """
    class_seg = SegByName("__class")
    segend = SegEnd(class_seg)

    cursor = class_seg
    while cursor < segend:
        name_ptr = Dword(cursor + 8)
        cls_name = get_ascii_contents(name_ptr, get_max_ascii_length(name_ptr, 0), 0)

        if cls_name:
            MakeNameEx(cursor, cls_name, 0x0)
            for field_offset, suffix in ((0x1C, "__mthd"), (0x0, "__meta"), (0x18, "__ivars")):
                target = Dword(cursor + field_offset)
                if target:
                    MakeNameEx(target, "%s%s" % (cls_name, suffix), 0x0)

        # sizeof(__class_struct-8) + 24 bytes padding
        cursor += 0x28 + 24


# Thanks pedram!
def get_marked_next():
    """Return the lowest free marked-position slot in the range 1..1024.

    A free slot is one for which GetMarkedPos() reports BADADDR (IDAPython
    uses BADADDR here, not -1).  Raises Exception("No slots found.") when all
    slots are taken.
    """
    for candidate in range(1, 1025):
        if GetMarkedPos(candidate) == BADADDR:
            return candidate
    raise Exception("No slots found.")

class StackChangedException(Exception):
    """Raised by traceback() when the stack-pointer delta changes mid-trace.

    The value passed to the constructor is kept in ``parameter``; ``str()``
    returns its ``repr()``.
    """

    def __init__(self, value):
        super(StackChangedException, self).__init__(value)
        self.parameter = value

    def __str__(self):
        return repr(self.parameter)

#####################################################################################
# Variable traceback:
#####################################################################################
def traceback(addr, variable_test):
    """Walk backwards from *addr* to the instruction that supplies a variable.

    The walk follows the code xref-from chain (RfirstB) while each step has
    exactly one predecessor.  It runs in two phases:

      1. Find a ``mov`` whose destination satisfies ``variable_test(ea, 0)``,
         and record its source operand.  The callback may read ``cvar.cmd``,
         which is decoded for each visited instruction before it is called.
      2. Find the ``mov`` that writes the recorded source operand and return
         its address.  If that operand is ``eax``, the nearest preceding
         ``call`` is returned instead (eax presumably holds its result).

    Returns the address found, or None when the walk hits an instruction with
    no predecessor or with several predecessors.

    Raises StackChangedException, after marking the position, when the stack
    pointer delta differs from the one at *addr* while a register is not being
    tracked.
    """
    # For now we'll assume it's a stack variable
    STATE_FIND_DEF = 0x0
    STATE_FIND_DEF_SRC = 0x1

    # TODO: watch EAXes for calls

    cursor = RfirstB(addr)
    ocursor = addr

    # Save the state of the starting point
    state = STATE_FIND_DEF
    spd = GetSpd(addr)
    trace_state = o_void
    targetop = None

    # We only want to deal with straight shots, keep it simple for now.
    # Assumes exactly one level of indirection.
    # Stop at BADADDR (no predecessor): RfirstB(BADADDR) is BADADDR again, so
    # without this check a register trace could spin forever.
    while cursor != BADADDR and RnextB(ocursor, cursor) == BADADDR:
        # Only bother with the spd check if it's not a register being tracked
        if not (state == STATE_FIND_DEF_SRC and trace_state == o_reg):
            if GetSpd(cursor) != spd:
                # Stack offset changed, not even trying to deal with this
                print("0x%08x: Bailing" % cursor)
                MarkPosition(cursor, 1, 1, 1, get_marked_next(), "0x%08x: Shifting SP for traceback" % cursor)
                raise StackChangedException("The stack changed at 0x%08x" % cursor)

        ua_ana0(cursor)  # load cmd struct for variable_test
        mnem = GetMnem(cursor)
        if state == STATE_FIND_DEF and mnem == "mov":
            if variable_test(cursor, 0):
                targetop = GetOpnd(cursor, 1)
                trace_state = GetOpType(cursor, 1)  # most likely o_reg
                state = STATE_FIND_DEF_SRC
        elif targetop == "eax" and mnem == "call":
            return cursor
        elif state == STATE_FIND_DEF_SRC and mnem == "mov":
            if GetOpnd(cursor, 0) == targetop:
                return cursor

        ocursor = cursor
        cursor = RfirstB(cursor)

    return None
        
#####################################################################################
# Rework functions with xrefs to _objc_msgSend:
#####################################################################################
def rebuild_function_contents():
    """Comment each _objc_msgSend call site with ``a = [recipient selector]``.

    Steps:
      * Locate the _objc_msgSend stub among the functions in __jump_table and
        collect every code xref to it.
      * Name the caller's first two frame members via rename_msgSend_args(),
        once per stack frame.
      * For each ``call`` site, trace back to the stores into [esp+4]
        (selector) and [esp] (recipient) and write the reconstructed message
        send as a comment.

    Sites whose operands cannot be traced are counted and reported.  A
    StackChangedException is reported and skipped; any other error is
    reported and re-raised.
    """
    seg = SegByName("__jump_table")
    segend = SegEnd(seg)

    # This should eventually iterate the other
    # function proxies.
    msgsend = None
    for fn in Functions(seg, segend):
        if Name(fn) == "_objc_msgSend":
            msgsend = fn
            break
    if msgsend is None:
        print("[-] _objc_msgSend not found in __jump_table")
        return

    # Every address with a code xref to the stub.
    callers = set()
    ref = RfirstB(msgsend)
    while ref != BADADDR:
        callers.add(ref)
        ref = RnextB(msgsend, ref)

    frames = set()
    trouble_fns = set()

    # Recipient => [esp]
    rec_lambda = lambda cursor, opnum: GetOpType(cursor, opnum) == o_phrase and RegNo.get(get_instruction_operand(cvar.cmd, opnum).reg) == "R_sp"
    # Selector => [esp+4]
    sel_lambda = lambda cursor, opnum: GetOpType(cursor, opnum) == o_displ and RegNo.get(get_instruction_operand(cvar.cmd, opnum).reg) == "R_sp" and get_instruction_operand(cvar.cmd, opnum).addr == 4

    for fn in callers:
        rec_addr = None
        try:
            # Rename the msgSend args once per stack frame. The old check tested
            # the call-site address against frame ids and so never matched.
            frame = GetFrame(fn)
            if frame is not None and frame != BADADDR and frame not in frames:
                print("[+] working in 0x%08x" % fn)
                rename_msgSend_args(fn)
                frames.add(frame)

            if GetMnem(fn) != "call":  # TODO: handle jmps
                continue

            sel_addr = traceback(fn, sel_lambda)
            rec_addr = traceback(fn, rec_lambda)
            # Both operands are required: deref_selector_strings(None) would
            # silently read from the screen cursor instead.
            if not (rec_addr and sel_addr):
                trouble_fns.add(fn)
                continue

            if GetOpType(rec_addr, 1) == o_mem:
                recip = deref_selector_strings(rec_addr)
            elif GetMnem(rec_addr) == "call":
                # Recipient is the eax result of an earlier call.
                recip = "a"
            else:
                recip = GetOpnd(rec_addr, 1)
            comment = "a = [%s %s]" % (recip, deref_selector_strings(sel_addr))
            set_cmt(fn, comment, False)

        except StackChangedException:
            print("Missed a spot.")
        except Exception:
            print("0x%08x: Error!" % fn)
            print("%s: recaddr!" % rec_addr.__class__)
            raise

    if trouble_fns:
        print("[-] %d call site(s) could not be traced" % len(trouble_fns))

# Give an [esp]-relative argument-store operand a "callee_argN" name.
def rename_args(cursor, num_args=1):
    """Set an alternate name on operand 0 of the instruction at *cursor*.

    Naming rules:
      * ``[esp]`` becomes ``callee_arg1``;
      * ``[esp+disp]`` becomes ``callee_arg<disp // 4 + 1>``, one slot per
        4-byte word;
      * any other operand is left unchanged.

    cursor   -- instruction address; None means ScreenEA().
    num_args -- unused; kept for interface compatibility.
    """
    if cursor is None:
        cursor = ScreenEA()

    # Decode *cursor* so cvar.cmd describes this instruction rather than
    # whatever instruction was analysed last.
    ua_ana0(cursor)
    op_type = GetOpType(cursor, 0)
    op = get_instruction_operand(cvar.cmd, 0)
    on_stack = RegNo.get(op.reg) == "R_sp"

    if op_type == o_phrase and on_stack:
        suffix = "1"
    elif op_type == o_displ and on_stack:
        suffix = str(op.addr // 4 + 1)
    else:
        return

    OpAlt(cursor, 0, "callee_arg" + suffix)


#####################################################################################
# main() code:
#####################################################################################
# Read by deref_selector_strings() to decide whether to write comments.
VERBOSE = True

# Cleanup pipeline, run in order. rebuild_functions_from_prologues must ALWAYS
# come first. Each step is paired with the message printed when it finishes
# (None = silent).
_PIPELINE = (
    (rebuild_functions_from_prologues, "[+] rebuilt from prologues"),
    (rename_class_structures, "[+] renamed class structures"),
    (rename_method_sections, None),
)
for _step, _done_msg in _PIPELINE:
    _step()
    if _done_msg is not None:
        print(_done_msg)


#print "[+] renamed methods"
#rebuild_function_contents()



# comment ebx=[esp]
#MakeComm(ScreenEA(), "ebx = 0x%08x" % NextNotTail(ScreenEA()))

# Fix broken code flow snippet:
# AddCodeXref(ScreenEA(), NextNotTail(ScreenEA()), fl_F)

