#!python

# Run BSP code in Stackless Python with one task per virtual processor
#
# Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
# last revision: 2004-3-29
#

import code, copy, getopt, ihooks, imp, os, pdb, sys, time, traceback, types
from StringIO import StringIO
import stackless
import Scientific.BSP

debug = 0

#
# End of task exception
#
class TaskEnded(Exception):

    """Exception signalling that the task of a virtual processor ended.

    The attribute pid holds the number of the virtual processor whose
    task has ended.
    """

    def __init__(self, pid):
        # Hand pid to the base class as well, so that args, str() and
        # repr() of the exception show it instead of being empty.
        Exception.__init__(self, pid)
        self.pid = pid

#
# Forced termination exception
#
class ForcedTermination(Exception):
    """Exception used to force the termination of a task."""

#
# The virtual parallel machine
#
class VirtualBSPMachine:

    """Virtual BSP machine with one Stackless tasklet per processor.

    run() executes the same code object once per virtual processor,
    each in its own namespace, and advances the processors superstep
    by superstep: every processor runs in turn until it calls sync()
    or its code ends, then the objects deposited with put()/send()
    are delivered and the next superstep starts.

    Two channels per processor connect the scheduler and the task:
    to_channels[pid] carries the "go" signal from the scheduler to the
    task, back_channels[pid] carries the "at barrier" signal (or a
    TaskEnded exception) from the task back to the scheduler.
    """

    def __init__(self, nproc, show_runtime=0):
        """Set up the per-processor state for nproc virtual processors.

        If show_runtime is true, run() prints the accumulated run time
        of each processor when it finishes.
        """
        pids = range(nproc)
        self.nproc = nproc
        # Use the argument, not a module-level variable of similar name.
        self.show_runtime = show_runtime
        self.current_pid = 0
        self.debugger = None
        self.to_channels = [stackless.channel() for pid in pids]
        self.back_channels = [stackless.channel() for pid in pids]
        self.name_spaces = [{'__name__': '__main__'} for pid in pids]
        self.mailboxes = [[] for pid in pids]
        self.runtime = [0. for pid in pids]
        self.output = [None for pid in pids]
        self.exception = [None for pid in pids]

    def currentPid(self):
        """Return the number of the processor that is currently running."""
        return self.current_pid

    def numberOfProcessors(self):
        """Return the number of virtual processors of this machine."""
        return self.nproc

    def isExtendedName(self, name):
        """Return true if name carries a processor prefix.

        The prefix has the form written by extendedName(): '_P', five
        digits and '_'. Slices are used instead of indexing so that
        short names starting with '_P' are rejected instead of raising
        IndexError.
        """
        return name[:2] == '_P' and name[2:7].isdigit() and name[7:8] == '_'

    def extendedName(self, name, pid=None):
        """Return name with the prefix of processor pid prepended.

        Without pid, the prefix of the currently running processor is
        used.
        """
        if pid is None:
            pid = self.current_pid
        return ("_P%05d_" % pid) + name

    def baseName(self, name):
        """Return name without its processor prefix, if it has one."""
        if self.isExtendedName(name):
            return name[8:]
        else:
            return name

    def virtualProcessor(self, pid):
        """Body of the tasklet that plays virtual processor pid.

        The code stored by run() is executed in the namespace of the
        processor. Any exception it raises is recorded in
        self.exception[pid] and its traceback is printed. Unless the
        task was stopped by ForcedTermination, the scheduler is told
        about the end of the task by a TaskEnded exception sent over
        the back channel.
        """
        # Tell run() that the task exists, then wait for the start
        # signal of the first superstep.
        self.back_channels[pid].send(None)
        self.to_channels[pid].receive()
        self.exception[pid] = None
        try:
            # Tuple form of exec: same meaning as "exec code in globals".
            exec(self.code, self.name_spaces[pid])
        except:
            # Deliberately catches everything, including
            # ForcedTermination: the exception is kept for pdb_pm-style
            # post-mortem inspection and reported right away.
            self.exception[pid] = sys.exc_info()
            self.showtraceback()
        if self.debugger is not None:
            sys.settrace(None)
        if self.exception[pid] is None \
               or self.exception[pid][0] != ForcedTermination:
            self.back_channels[pid].send_exception(TaskEnded, pid)

    def run(self, code, buffered_output = 0):
        """Run the code object code on all virtual processors.

        If buffered_output is true, everything a processor writes to
        sys.stdout or sys.stderr is collected in self.output[pid]
        (a StringIO object) instead of being printed immediately.

        The run ends after the first superstep in which at least one
        task has ended; all tasks still waiting at a barrier then
        receive ForcedTermination.
        """
        self.code = code
        pids = range(self.nproc)
        self.tasklets = []
        for pid in pids:
            if buffered_output:
                self.output[pid] = StringIO()
            task = stackless.tasklet(self.virtualProcessor)(pid)
            self.tasklets.append(task)
            task.setatomic(1)
            task.run()
        # Collect the "task exists" message of every processor.
        for pid in pids:
            self.back_channels[pid].receive()

        super_step = 0
        real_stdout = sys.stdout
        real_stderr = sys.stderr
        while 1:
            terminated = []
            if debug: print("*** Superstep  %s" % (super_step,))
            # The loop variable is the attribute itself because the
            # running task reads it through currentPid() and sync().
            for self.current_pid in pids:
                current = self.current_pid
                try:
                    if debug: print("*** Running  %s" % (current,))
                    start_time = time.time()
                    try:
                        if buffered_output:
                            sys.stdout = self.output[current]
                            sys.stderr = self.output[current]
                        if self.debugger is not None:
                            if current in self.debugged_pids:
                                print("*** Debugging pid %s" % (current,))
                                self.debugger.reset()
                                sys.settrace(self.debugger.trace_dispatch)
                            else:
                                print("*** Not debugging pid %s"
                                      % (current,))
                        # Release the task and wait until it reaches
                        # the next barrier or ends (TaskEnded).
                        self.to_channels[current].send(None)
                        self.back_channels[current].receive()
                    finally:
                        # Done on every exit path: the time spent in a
                        # task's final superstep is counted as well, and
                        # the real output streams are restored even if
                        # an unexpected exception passes through.
                        sys.settrace(None)
                        self.runtime[current] += time.time()-start_time
                        if buffered_output:
                            sys.stdout = real_stdout
                            sys.stderr = real_stderr
                    if debug: print("***  %s  at barrier" % (current,))
                except TaskEnded:
                    ended_pid = sys.exc_info()[1].pid
                    if debug:
                        print("*** Received exception %s" % (ended_pid,))
                    terminated.append(ended_pid)
            # Deliver what was sent during this superstep and start
            # with empty mailboxes for the next one.
            self.messages = self.mailboxes
            self.mailboxes = [[] for pid in pids]
            if debug: print("*** End of superstep  %s" % (super_step,))
            super_step += 1
            # "not pids": without processors no task can ever end, so
            # stop instead of looping forever.
            if terminated or not pids:
                for pid in pids:
                    if pid not in terminated:
                        self.to_channels[pid].send_exception(ForcedTermination)
                break
        if self.show_runtime:
            print("Runtimes:")
            for pid in pids:
                print("  Processor %d: %f s" % (pid, self.runtime[pid]))

    def showtraceback(self):
        """Write the traceback of the current exception to sys.stderr.

        The first traceback entry is left out. The exception is also
        stored in sys.last_type, sys.last_value and sys.last_traceback.
        """
        lines = []
        try:
            exc_type, exc_value, tb = sys.exc_info()
            sys.last_type = exc_type
            sys.last_value = exc_value
            sys.last_traceback = tb
            tblist = traceback.extract_tb(tb)
            del tblist[:1]
            lines = traceback.format_list(tblist)
            if lines:
                lines.insert(0, "Traceback (most recent call last):\n")
            lines.extend(traceback.format_exception_only(exc_type,
                                                         exc_value))
        finally:
            # Drop the references to the traceback.
            tblist = tb = None
        for line in lines:
            sys.stderr.write(line)

    def run_debug(self, command, debugger, pids):
        """Run command with debugger tracing the processors in pids.

        Output is not buffered during a debugging run.
        """
        self.debugger = debugger
        self.debugged_pids = pids
        try:
            self.run(command, 0)
        finally:
            # Leave debugging mode even if the run fails.
            self.debugger = None

    def put(self, object, pid_list):
        """Deposit a deep copy of object for each processor in pid_list.

        The copies are handed out by sync() in the next superstep.
        """
        if self.debugger:
            # presumably keeps the debugger out of the machine's own code
            sys.settrace(None)
        for pid in pid_list:
            self.mailboxes[pid].append(copy.deepcopy(object))

    def send(self, messages):
        """Deposit data for processor pid for each (pid, data) pair
        in messages."""
        if self.debugger:
            sys.settrace(None)
        for pid, data in messages:
            self.put(data, [pid])

    def sync(self):
        """Wait at the barrier that ends the current superstep.

        Returns the list of objects that were deposited for the calling
        processor during the superstep that has just ended.
        """
        if self.debugger:
            sys.settrace(None)
        pid = self.current_pid
        if debug: print("---  %s  at barrier" % (pid,))
        self.back_channels[pid].send(None)
        if debug: print("---  %s  waiting" % (pid,))
        self.to_channels[pid].receive()
        if debug: print("---  %s  released" % (pid,))
        messages = self.messages[pid]
        self.messages[pid] = None
        return messages

#
# Change module imports such that every virtual processor gets its
# own copy of all modules except for builtin modules.
#
class Hooks(ihooks.Hooks):

    """Import hooks that look up builtin and frozen modules by base name.

    Every method removes the processor prefix from the module name
    (machine.baseName) and passes the result to the function of the
    same name in the imp module.
    """

    def is_builtin(self, name):
        base = machine.baseName(name)
        return imp.is_builtin(base)

    def init_builtin(self, name):
        base = machine.baseName(name)
        return imp.init_builtin(base)

    def is_frozen(self, name):
        base = machine.baseName(name)
        return imp.is_frozen(base)

    def init_frozen(self, name):
        base = machine.baseName(name)
        return imp.init_frozen(base)

    def get_frozen_object(self, name):
        base = machine.baseName(name)
        return imp.get_frozen_object(base)

class ModuleLoader(ihooks.ModuleLoader):

    """Module loader that searches directories by base module name."""

    def find_module_in_dir(self, name, dir, allow_packages=1):
        # A directory search uses the name without processor prefix;
        # with dir None the unchanged name goes to the builtin search.
        if dir is not None:
            return ihooks.ModuleLoader.find_module_in_dir(
                self, machine.baseName(name), dir, allow_packages)
        return self.find_builtin_module(name)

class ModuleImporter(ihooks.ModuleImporter):

    """Module importer that imports modules under per-processor names.

    import_module() adds the prefix of the running processor to the
    requested name, so that each virtual processor obtains its own
    entry in the module table.
    """

    def __init__(self, loader = None):
        ihooks.ModuleImporter.__init__(self, loader)
        # Register the modules that are already in the module table
        # under the extended name of every processor. Names that are
        # already extended and the Scientific package are left out.
        for name in self.modules.keys():
            skip = (machine.isExtendedName(name)
                    or name == "Scientific"
                    or name[:11] == "Scientific.")
            if skip:
                continue
            module = self.modules[name]
            for pid in range(machine.nproc):
                # NOTE(review): extendedName is called with (name, pid)
                # here - confirm that the machine's extendedName accepts
                # a pid argument.
                self.modules[machine.extendedName(name, pid)] = module

    def import_module(self, name, globals=None, locals=None, fromlist=None):
        extended = machine.extendedName(name)
        return ihooks.ModuleImporter.import_module(self, extended,
                                                   globals, locals, fromlist)

    def import_it(self, partname, fqname, parent, force_load=0):
        if not partname:
            raise ValueError("Empty module name")
        # Reuse a module that is already in the table unless a fresh
        # load is requested.
        if not force_load and fqname in self.modules:
            return self.modules[fqname]
        try:
            path = parent and parent.__path__
        except AttributeError:
            return None
        found = self.loader.find_module(partname, path)
        if not found:
            return None
        module = self.loader.load_module(fqname, found)
        if parent:
            # The attribute of the parent package gets the base name.
            setattr(parent, machine.baseName(partname), module)
        return module

    def find_head_package(self, parent, name):
        dot = name.find('.')
        if dot >= 0:
            head, tail = name[:dot], name[dot+1:]
        else:
            head, tail = name, ""
        if parent:
            qname = "%s.%s" % (parent.__name__, machine.baseName(head))
        else:
            qname = head
        module = self.import_it(head, qname, parent)
        if module:
            return module, tail
        if parent:
            # Second attempt: look for head as a top-level module.
            qname = head
            parent = None
            module = self.import_it(head, qname, parent)
            if module:
                return module, tail
        raise ImportError("No module named " + qname)

#
# An interactive parallel console
#
class VirtualConsole(code.InteractiveConsole):

    """Interactive console that runs its input on a virtual machine.

    Compiled input is executed by virtual_machine.run() with buffered
    output, and the output is then shown separately for each
    processor. A line starting with '!' is not sent to the processors;
    it is executed directly in a namespace that provides pm(pid) and
    run(pids, command) for debugging with pdb.
    """

    def __init__(self, virtual_machine):
        code.InteractiveConsole.__init__(self)
        self.virtual_machine = virtual_machine

    def runcode(self, compiled_code):
        """Run compiled_code on the virtual machine and show its output."""
        try:
            self.virtual_machine.run(compiled_code, 1)
            self.showOutput()
        except SystemExit:
            raise
        except:
            self.showOutput()
            self.virtual_machine.showtraceback()
        else:
            if code.softspace(sys.stdout, 0):
                sys.stdout.write('\n')

    def showOutput(self):
        """Print the buffered output of each processor under a header
        line. Processors without output are left out."""
        for pid in range(self.virtual_machine.numberOfProcessors()):
            buffer = self.virtual_machine.output[pid]
            if buffer is None:
                # No buffered run has taken place for this processor.
                continue
            output = buffer.getvalue()
            if output:
                sys.stdout.write(("-- Processor %d " % pid) + 40*'-' + '\n')
                sys.stdout.write(output)

    def push(self, line):
        """Handle one line of input.

        Lines starting with '!' are executed as console commands and
        always return 0 (no more input required); all other lines are
        passed on to code.InteractiveConsole.push.
        """
        if line and line[0] == '!':
            command = line.strip()[1:]
            try:
                # Tuple form of exec: same meaning as "exec command in
                # namespace".
                exec(command, {'pm': self.pdb_pm, 'run': self.pdb_run})
            except SystemExit:
                raise
            except:
                # Report a failing command instead of letting the
                # exception end the interactive session.
                self.showtraceback()
            return 0
        return code.InteractiveConsole.push(self, line)

    def pdb_pm(self, pid):
        """Start a pdb post-mortem session for the exception recorded
        for processor pid."""
        exception = self.virtual_machine.exception[pid]
        if exception is None:
            self.write("No exception recorded for processor %d\n" % pid)
            return
        pdb.post_mortem(exception[-1])

    def pdb_run(self, pids, command):
        """Run the source string command with pdb tracing the
        processors in pids (a single pid or a sequence of pids)."""
        try:
            command = compile(command, '<string>', 'exec')
        except SyntaxError:
            self.showsyntaxerror()
            return
        if isinstance(pids, int):
            pids = [pids]
        debugger = pdb.Pdb()
        try:
            self.virtual_machine.run_debug(command, debugger, pids)
        finally:
            # Switch the debugger off even if the run fails.
            debugger.quitting = 1
            sys.settrace(None)
        
# Command line: [-i] [-t] nproc [script]
#   -i  enter the interactive console (after the script, if one is given)
#   -t  print the run time of each processor
usage = "Usage: %s [-i] [-t] nproc [script]\n" % sys.argv[0]
try:
    options, args = getopt.getopt(sys.argv[1:], 'it')
except getopt.GetoptError:
    sys.stderr.write("%s\n" % (sys.exc_info()[1],))
    sys.stderr.write(usage)
    sys.exit(2)
interactive = 0
runtime = 0
for option, value in options:
    if option == '-i':
        interactive = 1
    if option == '-t':
        runtime = 1
# Without a script there is nothing to do but interactive work.
if len(args) < 2:
    interactive = 1

# The number of processors is required and must be a positive integer.
try:
    nproc = int(args[0])
except (IndexError, ValueError):
    sys.stderr.write(usage)
    sys.exit(2)
if nproc < 1:
    sys.stderr.write("The number of processors must be at least 1\n")
    sys.exit(2)
machine = VirtualBSPMachine(nproc, runtime)
sys.virtual_bsp_machine = machine

# Install the import mechanism that gives every virtual processor its
# own copy of the modules it imports.
hooks = Hooks()
loader = ModuleLoader(hooks)
importer = ModuleImporter(loader)
sys.setrecursionlimit((nproc+1)*1000)
importer.install()

cprt = 'Type "copyright", "credits" or "license" for more information.'
banner = "Python %s on %s\n%s\n(Parallel console, %d processors)" % \
         (sys.version, sys.platform, cprt, nproc)

if len(args) > 1:
    script_name = args[1]
    # Close the script file explicitly once it has been read.
    script_file = open(script_name)
    try:
        script = script_file.read()
    finally:
        script_file.close()
    compiled_code = compile(script, script_name, 'exec')
    machine.run(compiled_code)
    banner = ''

if interactive:
    console = VirtualConsole(machine)
    console.interact(banner)

