#-*- coding: utf-8 -*-

# Copyright 2012 Calculate Ltd. http://www.calculate-linux.org
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import pickle, random
import threading
import sys, os
import traceback
from traceback import print_exc
from api_types import ReturnProgress,ViewParams,Integer,ViewInfo
from decorators import Dec
from calculate.lib.cl_lang import setLocalTranslate,getLazyLocalTranslate
from soaplib.serializers.primitive import Integer
from soaplib.service import rpc
setLocalTranslate('cl_core3',sys.modules[__name__])
__ = getLazyLocalTranslate(_)

def commonView(self,sid,params,arg):
    """
    Build the ViewInfo for method *arg* of session *sid*.

    The variables object is taken from self.get_cache(sid, arg, "vars");
    when nothing usable is cached it is created with self.<arg>_vars(),
    otherwise processRefresh() is called on the cached one. The object is
    stored back with self.set_cache(..., smart=False) and a ViewInfo built
    from it (with *params* as viewparams) is returned.
    """
    variables = self.get_cache(sid, arg, "vars")
    if variables:
        variables.processRefresh()
    else:
        factory = getattr(self, "%s_vars" % arg)
        variables = factory()
    result = ViewInfo(variables, viewparams=params)
    self.set_cache(sid, arg, "vars", variables, smart=False)
    return result

def catchExcept(*skipException):
    """
    Return a decorator class that converts exceptions into an error view.

    Calling the decorated object calls the wrapped function. If it raises
    anything (BaseException), a ViewInfo holding a single "Error" group with
    one "error" field (label = the exception text) is returned instead.
    The traceback is written to stdout unless the exception is an instance
    of one of the classes in *skipException*.
    """
    class wrapper:
        def __init__(self, f):
            # copy the metadata of the wrapped function so that code which
            # introspects the decorated object still sees the original one
            self.f = f
            self.func_name = f.func_name
            self.func_code = f.func_code
            self.__doc__ = f.__doc__
            self.__name__ = f.__name__

        def __call__(self, *args, **kwargs):
            try:
                return self.f(*args, **kwargs)
            except BaseException as e:
                from calculate.core.server.api_types import ViewInfo,GroupField,Field
                error_view = ViewInfo(groups=[])
                error_group = GroupField(name=_("Error"), last=True)
                error_group.fields = [Field(name="error",
                                            label=str(e),
                                            default='color:red;',
                                            element="error")]
                error_view.groups.append(error_group)
                skipped = False
                for exc_class in skipException:
                    if isinstance(e, exc_class):
                        skipped = True
                        break
                if not skipped:
                    # every formatted line already ends with a newline
                    for text in traceback.format_exception(*sys.exc_info()):
                        sys.stdout.write(text)
                return error_view
    return wrapper

def shortTraceback(e1,e2,e3):
    """
    Return a one-line summary of an exception.

    The arguments are the triple returned by sys.exc_info(). The result has
    the form ``ExcType:message(module:lineno)``, where *module* and *lineno*
    describe the innermost traceback entry (the place the exception was
    raised). If that file lives below a ``site-packages`` directory, the
    dotted module path relative to it is used (e.g.
    ``calculate.core.server.func``); otherwise only the file name without
    ``.py`` is used.
    """
    frame = e3
    # walk to the innermost traceback entry
    while frame.tb_next:
        frame = frame.tb_next
    module, part = os.path.split(frame.tb_frame.f_code.co_filename)
    if part.endswith('.py'):
        part = part[:-3]
    fallbackmod = part
    modname = [part]
    # Climb the directory chain until site-packages is found. Stop at the
    # filesystem root, at the start of a relative path ('' - e.g. for
    # '<string>' or 'pkg/mod.py', which previously looped forever), or when
    # os.path.split stops making progress.
    while module not in ('/', '') and not module.endswith('site-packages'):
        parent, part = os.path.split(module)
        if parent == module:
            break
        module = parent
        modname.insert(0, part)
    if module.endswith('site-packages'):
        modname = ".".join(modname)
    else:
        modname = fallbackmod
    # BaseException.message only exists on Python 2 (and is deprecated
    # there); fall back to the string form of the exception
    message = getattr(e2, 'message', None)
    if message is None:
        message = str(e2)
    return "%s:%s(%s:%s)" % (e1.__name__, message, modname, frame.tb_lineno)

def safetyWrapper(native_errors=(Exception,),
                  man_int=__("Manually interrupted"),
                  post_action=lambda self,x:None,
                  failed_message="",
                  success_message=""):
    """
    Standard decorator for logical methods called by wsdl.

    The decorated method runs with error reporting through the object's
    printERROR/printSUCCESS/endTask/endFrame methods:

    * EOFError and *native_errors* are reported by their message; any other
      Exception by a shortTraceback summary; KeyboardInterrupt by *man_int*.
    * post_action(self, error) is always called; error is None on success.
    * On failure *failed_message* (if any) is printed and False is returned.
      On success *success_message* (if any) is printed, endTask() is called
      and the method's result is returned.
    * self.clVars.close() (when present) and endFrame() always run afterwards.
    """
    def describe(exc):
        # An exception raised without a message stringifies to '' - that
        # falsy value used to be mistaken for success (and then `res` was
        # unbound). Fall back to the exception class name.
        return str(exc) or exc.__class__.__name__

    def wrapper(f):
        def tmp(self, *args, **kwargs):
            error = None
            try:
                try:
                    res = f(self, *args, **kwargs)
                except EOFError as e:
                    error = describe(e)
                except native_errors as e:
                    error = describe(e)
                except Exception:
                    error = shortTraceback(*sys.exc_info())
                except KeyboardInterrupt:
                    error = str(man_int)
                if error:
                    self.printERROR(error)
                try:
                    post_action(self, error)
                except native_errors as e:
                    error = describe(e)
                    self.printERROR(error)
                except KeyboardInterrupt:
                    pass
                if error:
                    mess = str(failed_message)
                    if mess:
                        self.printERROR(mess)
                    return False
                mess = str(success_message)
                if mess:
                    self.printSUCCESS(mess)
                self.endTask()
                return res
            except BaseException:
                # anything escaping the handling above (e.g. from printERROR
                # or a non-native error in post_action)
                error = shortTraceback(*sys.exc_info())
                self.printERROR(error)
                return False
            finally:
                try:
                    if hasattr(self.clVars, "close"):
                        self.clVars.close()
                except BaseException:
                    error = shortTraceback(*sys.exc_info())
                    self.printERROR(error)
                    # overrides whatever the try/except above returned
                    return False
                finally:
                    self.endFrame()
        return tmp
    return wrapper

class CoreWsdl():
    """
    Core service methods: client sessions, certificate-based method
    listing and background process management.

    NOTE(review): the methods rely on attributes that are not defined in
    this class (sids, sids_file, max_sid, pids, data_path, certbase,
    group_rights, rights, log_filename, manager, glob_process_dict,
    glob_frame_list, glob_progress_dict, glob_table_dict, process_pid,
    gen_pid, add_sid_pid, watcher_pid_proc); presumably they are provided
    by the class that composes this one.
    """

    # client signals about presence
    def active_clients(self, sid):
        """
        Reset the presence counters (items 1 and 2) of session *sid*.

        Returns 0 on success, 1 when the session file is missing or cannot
        be read/written, 2 when *sid* is outside (0, self.max_sid).
        """
        self.get_lang(sid)
        if not 0 < sid < self.max_sid:
            return 2
        try:
            sid_file = self.sids + "/%d.sid" % sid
            if not os.path.isfile(sid_file):
                return 1
            with open(sid_file, 'rb') as fd:
                # read information about session
                sid_inf = pickle.load(fd)
            # reset counters
            sid_inf[1] = 0
            sid_inf[2] = 0
            # the session may have been removed while it was being read
            if not os.path.isfile(sid_file):
                return 1
            with open(sid_file, 'wb') as fd:
                pickle.dump(sid_inf, fd)
            return 0
        except Exception:
            return 1

    def serv_get_methods(self, client_type):
        """
        Return the methods the calling client's certificate has rights for.

        The certificate is taken from threading.currentThread().client_cert.
        A method is kept when every right in Dec.rightsMethods[meth[1]] is
        granted. Console clients get return_conMethod() entries, all other
        client types return_guiMethod() entries. [['0', '0']] is returned
        when nothing is allowed.
        """
        certificate = threading.currentThread().client_cert
        from cert_cmd import find_cert_id
        cert_id = find_cert_id(certificate, self.data_path, self.certbase)

        rights = self.serv_view_cert_right(cert_id, self.data_path,
                                           client_type)
        if client_type == "console":
            methods = self.return_conMethod()
        else:
            methods = self.return_guiMethod()
        return_list = [meth for meth in methods
                       if all(right in rights
                              for right in Dec.rightsMethods[meth[1]])]
        if not return_list:
            return [['0', '0']]
        return return_list

    # return a list of methods for the console as list
    def return_conMethod(self):
        """ Return [name, *Dec.conMethods[name]] for every console method. """
        from decorators import Dec
        return [[item] + list(Dec.conMethods[item])
                for item in Dec.conMethods]

    # return a list of methods for the GUI as list
    def return_guiMethod(self):
        """
        Return GUI method records as [name, v1, v2, v3] lists.

        Dec.guiMethods[name] is read as consecutive groups of three values;
        one record is produced per group.
        """
        from decorators import Dec
        results = []
        for item in Dec.guiMethods:
            values = Dec.guiMethods[item]
            for i in range(0, len(values), 3):
                results.append([item] + [values[i + j] for j in range(3)])
        return results

    # get available sessions
    def serv_get_sessions(self):
        """
        Return the ids (as strings) of all sessions in self.sids_file.

        The file holds consecutive pickled records whose first item is the
        session id; reading stops at the first record that fails to load.
        """
        result = []
        with open(self.sids_file, 'rb') as fd:
            while True:
                try:
                    # read all on one record
                    list_sid = pickle.load(fd)
                except Exception:
                    break
                result.append(str(list_sid[0]))
        return result

    # check client alive
    @staticmethod
    def client_alive(sid, SIDS_DIR):
        """
        Check the presence flag of session *sid* stored in *SIDS_DIR*.

        Returns 1 when the session file does not exist. Otherwise returns 0
        when item 2 of the stored record equals 1 (commented as the "absence"
        flag) and 1 in every other case.

        It is a staticmethod: it never used an instance, and without the
        decorator calling it on an instance raised TypeError.
        """
        sid_path = SIDS_DIR + "/%d.sid" % sid
        if not os.path.isfile(sid_path):
            return 1
        with open(sid_path, 'rb') as fd:
            # read information about session
            sid_inf = pickle.load(fd)
        return 0 if sid_inf[2] == 1 else 1

    class Common:
        """ class to interact with the processes """
        def __init__(self, process_dict, progress_dict, table_dict,
                     frame_list, pid):
            self.process_dict = process_dict
            self.progress_dict = progress_dict
            # 'id' holds the key of the current progress entry
            self.progress_dict['id'] = 0
            self.table_dict = table_dict
            self.frame_list = frame_list
            self.pid = pid
            # upper bound for randomly chosen progress/table ids
            self.Num = 100000

        def writeFile(self):
            """ Pickle process_dict into <Basic.pids>/<pid>.pid. """
            from baseClass import Basic
            if not os.path.exists(Basic.pids):
                try:
                    os.makedirs(Basic.pids)
                except OSError:
                    # reported below when the PID file cannot be opened
                    pass
            self.PID_FILE = Basic.pids + '/%d.pid' % self.pid
            try:
                with open(self.PID_FILE, "wb") as _fc:
                    pickle.dump(self.process_dict, _fc)
            except Exception:
                # NOTE: the text says "read" although this is a write; it is
                # kept unchanged because it is a translation key
                print(_("Failed to read the PID file %s!") % self.PID_FILE)

        def setProgress(self, perc, short_message=None, long_message=None):
            progress_id = self.progress_dict['id']
            self.progress_dict[progress_id] = ReturnProgress(
                perc, short_message, long_message)

        def setStatus(self, stat):
            self.process_dict['status'] = stat

        def setData(self, dat):
            self.data_list = dat

        def getStatus(self):
            return self.process_dict['status']

        def getProgress(self):
            """ Return the current progress percent, 0 if none is set. """
            progress_id = self.progress_dict['id']
            if progress_id in self.progress_dict:
                return self.progress_dict[progress_id].percent
            return 0

        def getAnswer(self):
            """
            Block until process_dict['answer'] is set, then consume it.

            Polls every 2 seconds; afterwards the answer is reset to None,
            the last frame message (the question) is popped and the
            'counter' entry is decremented.
            """
            import time
            while self.process_dict['answer'] is None:
                time.sleep(2)
            res = self.process_dict['answer']
            self.process_dict['answer'] = None
            self.frame_list.pop()
            self.process_dict['counter'] -= 1
            return res

        def addProgress(self):
            """ Start a new progress entry under a fresh random id. """
            progress_id = random.randint(1, self.Num)
            while progress_id in self.progress_dict:
                progress_id = random.randint(1, self.Num)
            self.progress_dict['id'] = progress_id
            self.progress_dict[progress_id] = ReturnProgress(0, '', '')
            self.addMessage(type='progress', id=progress_id)

        def printTable(self, table_name, head, body, fields=None,
                       onClick=None, addAction=None):
            """ Store a table (cells converted to str) and announce it. """
            table_id = random.randint(1, self.Num)
            while table_id in self.table_dict:
                table_id = random.randint(1, self.Num)

            from api_types import Table
            table = Table(head=head,
                          body=[[str(cell) for cell in row] for row in body],
                          fields=fields, onClick=onClick,
                          addAction=addAction, values=None)
            self.table_dict[table_id] = table
            self.addMessage(type='table', message=table_name, id=table_id)

        def addMessage(self, type='normal', message=None, id=None):
            from api_types import Message
            message = Message(type=type, message=message, id=id)
            self.frame_list.append(message)

        def printSUCCESS(self, message=''):
            self.addMessage(type='normal', message=message)

        def printDefault(self, message=''):
            self.addMessage(type='plain', message=message)

        def printWARNING(self, message):
            self.addMessage(type='warning', message=message)

        def printERROR(self, message=''):
            """
            Emit an error message and mark the progress as failed.

            Progress 0 becomes 100, a positive value is negated, a negative
            value is set again (which clears the progress messages).
            """
            perc = self.getProgress()
            if perc == 0:
                self.setProgress(100)
            elif perc > 0:
                self.setProgress(-perc)
            else:
                self.setProgress(perc)
            self.addMessage(type='error', message=message)

        def startTask(self, message, progress=False, num=1):
            self.addMessage(type='startTask', message=message, id=num)
            if progress:
                self.addProgress()

        def setTaskNumber(self, number=None):
            self.addMessage(type='taskNumber', message=str(number))

        def endTask(self, result=None, progress_message=None):
            self.addMessage(type='endTask', message=result)
            self.setProgress(100, progress_message)

        def askQuestion(self, message):
            self.addMessage(type='question', message=message)
            return self.getAnswer()

        def askPassword(self, message, twice=False):
            # the id tells the client how many times to ask for the password
            pas_repeat = 2 if twice else 1
            self.addMessage(type='password', message=message, id=pas_repeat)
            return self.getAnswer()

        def beginFrame(self, message=None):
            self.addMessage(type='beginFrame', message=message)

        def endFrame(self):
            self.addMessage(type='endFrame')

        def startGroup(self, message):
            self.addMessage(type='startGroup', message=message)

        # NOTE: the misspelled name and message type are kept because
        # callers and clients use them
        def endGruop(self):
            self.addMessage(type='endGruop')

        def briefParams(self, view_name):
            self.addMessage(type='briefParams', message=view_name)

    def startprocess(self, sid, target=None, method=None, method_name=None,
                     auto_delete=False, args_proc=()):
        """
        Start *method* of a fresh *target* object in a separate process.

        target is instantiated as target(process_dict, progress_dict,
        table_dict, frame_list, pid) with shared containers created by
        self.manager, and getattr(com, method) is run through target_helper
        in a multiprocessing.Process. With auto_delete a watcher thread
        (self.watcher_pid_proc) is started for the pid.

        args_proc: extra positional arguments for the method (any iterable).
        It used to default to {}, which cannot be added to a tuple, so calls
        relying on the default failed.

        Returns the new pid as a string.
        """
        pid = self.gen_pid()
        self.add_sid_pid(sid, pid)

        import multiprocessing
        if self.manager is None:
            # one Manager is shared by all instances of the class
            self.__class__.manager = multiprocessing.Manager()
        self.glob_process_dict[pid] = self.manager.dict()
        process_dict = self.glob_process_dict[pid]
        process_dict['sid'] = sid
        process_dict['status'] = 0
        process_dict['time'] = ""
        process_dict['answer'] = None
        process_dict['name'] = ""
        process_dict['flag'] = 0
        process_dict['counter'] = 0

        self.glob_frame_list[pid] = self.manager.list()
        self.glob_progress_dict[pid] = self.manager.dict()
        self.glob_table_dict[pid] = self.manager.dict()

        # create object Common and send parameters
        com = target(self.glob_process_dict[pid],
                     self.glob_progress_dict[pid],
                     self.glob_table_dict[pid],
                     self.glob_frame_list[pid], pid)

        # explicitly initialise the second base class of the target class
        # (presumably the logic class mixed with Common - TODO confirm)
        bases = com.__class__.__bases__
        if len(bases) > 1 and hasattr(bases[1], '__init__'):
            bases[1].__init__(com)
        # start helper
        helper_args = (com, getattr(com, method), method_name) + \
            tuple(args_proc)
        p = multiprocessing.Process(target=self.target_helper,
                                    args=helper_args)

        self.process_pid[pid] = p
        p.start()
        if auto_delete:
            # start watcher (for kill process on signal)
            watcher = threading.Thread(target=self.watcher_pid_proc,
                                       args=(sid, pid))
            watcher.start()
        return str(pid)

    # wrap all method
    def target_helper(self, com, target_proc, method_name, *args_proc):
        """
        Run *target_proc* in the worker process and record its outcome.

        Before the call, status 1, start time, method_name and the function
        name are stored in com.process_dict. Afterwards the status is:
        0 when the call returned True; 2 when it returned False (or raised)
        while still at 1; otherwise 2 if still at 1, else 0. The PID file is
        written, an unfinished positive progress is negated and the frame
        list is terminated with an 'endFrame' message if needed. Tracebacks
        are printed and appended to self.log_filename.
        """
        import datetime

        def log_exception():
            # print the active traceback and append it to the log file
            print_exc()
            with open(self.log_filename, 'a') as fd:
                print_exc(file=fd)

        if not os.path.exists(self.pids):
            try:
                os.makedirs(self.pids)
            except OSError:
                # writeFile reports the failure when it cannot write
                pass

        com.process_dict['status'] = 1
        com.process_dict['time'] = datetime.datetime.now()
        com.process_dict['method_name'] = method_name
        com.process_dict['name'] = target_proc.__func__.__name__

        try:
            result = target_proc(*args_proc)
        except BaseException:
            result = False
            log_exception()
        try:
            if result == True:
                com.setStatus(0)
            elif result == False:
                if com.getStatus() == 1:
                    com.setStatus(2)
            elif com.getStatus() == 1:
                com.setStatus(2)
            else:
                com.setStatus(0)
            com.writeFile()
            try:
                progress = com.getProgress()
                if 0 < progress < 100:
                    com.setProgress(-progress)
            except Exception:
                pass

            if not len(com.frame_list) or \
                    com.frame_list[-1].type != 'endFrame':
                com.endFrame()
        except Exception:
            log_exception()
            com.endFrame()

    def serv_view_cert_right(self, cert_id, data_path, client_type=None):
        """
        Return the list of rights of the certificate *cert_id*.

        Groups come from the last extension of
        <data_path>/client_certs/<cert_id>.crt (text after ':' split on ',').
        Group rights are read from self.group_rights, one
        "<group> <right>,<right>,..." per line. Group "all" grants every
        right of every console (or GUI) method unless "all" is redefined in
        that file. self.rights lines are "<right> <id> <id> ..."; a positive
        id equal to cert_id grants the right, a negative one revokes it.

        Returns ["-2"] for a non-integer cert_id, ["-1"] when the
        certificate file is missing and ["No Methods"] when nothing is left.
        """
        try:
            cert_id = int(cert_id)
        except Exception:
            return ["-2"]
        cert_file = data_path + '/client_certs/%s.crt' % str(cert_id)
        if not os.path.exists(cert_file):
            return ["-1"]
        with open(cert_file, 'r') as fd:
            cert = fd.read()

        import OpenSSL
        certobj = OpenSSL.crypto.load_certificate(
            OpenSSL.SSL.FILETYPE_PEM, cert)
        com = certobj.get_extension(
            certobj.get_extension_count() - 1).get_data()
        groups_list = com.split(':')[1].split(',')

        # create the group rights file up front: the "all" branch below
        # read it before the file was created
        if not os.path.exists(self.group_rights):
            open(self.group_rights, 'w').close()
        with open(self.group_rights) as fd:
            group_lines = fd.read().splitlines()

        results = []
        find_flag = False
        # if group = all and not redefined group all
        if 'all' in groups_list:
            for line in group_lines:
                words = line.split()
                # whitespace-only lines used to raise IndexError here
                if words and words[0] == 'all':
                    find_flag = True
                    break
            if not find_flag:
                if client_type == 'console':
                    methods = self.return_conMethod()
                else:
                    methods = self.return_guiMethod()
                results = uniq([right for meth_list in methods
                                for right in Dec.rightsMethods[meth_list[1]]])

        if 'all' not in groups_list or find_flag:
            for line in group_lines:
                words = line.split(' ', 1)
                if len(words) < 2:
                    continue
                # first word in line equal name input method
                if words[0] in groups_list:
                    results.extend(method.strip()
                                   for method in words[1].split(','))
            results = uniq(results)

        add_list_rights = []
        del_list_rights = []
        with open(self.rights) as fr:
            for line in fr.read().splitlines():
                words = line.split()
                if not words:
                    # blank lines used to crash with IndexError
                    continue
                meth = words[0]
                for word in words:
                    try:
                        word = int(word)
                    except ValueError:
                        continue
                    # compare with certificate number
                    if cert_id == word:
                        add_list_rights.append(meth)
                    if cert_id == -word:
                        del_list_rights.append(meth)

        results = uniq(results + add_list_rights)
        # filter instead of list.remove() inside the loop, which skipped the
        # element following every removed one
        results = [method for method in results
                   if method not in del_list_rights]

        if not results:
            results.append("No Methods")
        return results

    def get_lang(self, sid):
        """
        Return the locale name (locale.locale_alias value) of session *sid*.

        The language is item 3 of the first record in <self.sids>/<sid>.sid
        whose item 0 equals *sid*. Anything other than 'fr', 'ru' or 'en'
        (case-insensitive), or a missing session, yields English.
        """
        lang = None
        sid_file = self.sids + "/%d.sid" % int(sid)
        if os.path.exists(sid_file):
            with open(sid_file, 'rb') as fd:
                while True:
                    try:
                        list_sid = pickle.load(fd)
                    except Exception:
                        break
                    # if session id found
                    if sid == list_sid[0]:
                        lang = list_sid[3]
                        break
        try:
            if lang.lower() not in ['fr', 'ru', 'en']:
                lang = "en"
        except Exception:
            # lang is None or not a string
            lang = "en"
        import locale
        try:
            return locale.locale_alias[lang.lower()]
        except KeyError:
            return locale.locale_alias['en']

def create_symlink(data_path,old_data_path):
    """
    Synchronise the command symlinks of console methods with Dec.conMethods.

    For every name in Dec.conMethods a symlink to /usr/sbin/cl-core is
    created in /usr/bin when the method's entry [1] is truthy, otherwise in
    /usr/sbin. Newly created links are recorded in <data_path>/conf/symlinks
    (a symlinks file found under old_data_path is moved there first).
    Recorded links whose command no longer exists, or which sit in the wrong
    directory, are then removed and dropped from the file.
    """
    meths = Dec.conMethods
    path_to_link = '/usr/sbin'
    path_to_user_link = '/usr/bin'
    conf_dir = os.path.join(data_path, 'conf')
    old_symlinks_file = os.path.join(old_data_path, 'conf/symlinks')
    symlinks_file = os.path.join(data_path, 'conf/symlinks')
    if not os.path.exists(conf_dir):
        try:
            os.makedirs(conf_dir)
        except OSError:
            print(_("cannot create directory %s") % conf_dir)
    # migrate the symlink registry from the old data location
    if os.path.exists(old_symlinks_file) and not os.path.exists(symlinks_file):
        with open(old_symlinks_file) as src:
            content = src.read()
        with open(symlinks_file, 'w') as dst:
            dst.write(content)
        os.unlink(old_symlinks_file)

    with open(symlinks_file, 'a') as fd:
        for link in meths:
            if meths[link][1]:
                link_path = os.path.join(path_to_user_link, link)
            else:
                link_path = os.path.join(path_to_link, link)
            if os.path.islink(link_path):
                continue
            if os.path.isfile(link_path):
                red = '\033[31m * \033[0m'
                print(red + link_path + _(' is a file, not a link!'))
                continue
            try:
                os.symlink(os.path.join(path_to_link, 'cl-core'), link_path)
            except OSError as e:
                # str(e) includes errno and strerror; e.message is typically
                # empty for OSError. Do not report a link that was not created.
                print(str(e))
                continue
            fd.write(link_path + '\n')
            print(_('Symlink %s created') % link_path)

    from calculate.lib.utils.files import readLinesFile
    kept_lines = []
    for line in readLinesFile(symlinks_file):
        cmdname = os.path.basename(line)
        obsolete = (cmdname not in meths or
                    (line.startswith(path_to_link) and meths[cmdname][1]) or
                    (line.startswith(path_to_user_link) and
                     not meths[cmdname][1]))
        if obsolete:
            if os.path.islink(line):
                os.unlink(line)
                print(_('Symlink %s deleted') % line)
        else:
            kept_lines.append(line + '\n')
    with open(symlinks_file, 'w') as fd:
        fd.write(''.join(kept_lines))

def initialization(cl_wsdl):
    """
    Find the Wsdl classes of the given packages for the server class.

    Each non-empty entry of *cl_wsdl* is a package name such as
    "calculate-core", mapped to the module "calculate.core.cl_wsdl_core";
    that module's Wsdl attribute is collected. Errors raised while
    importing the module propagate. A module without a Wsdl attribute is
    reported on stderr and skipped; this used to crash, because the handler
    caught ImportError while attribute access raises AttributeError.
    """
    import importlib
    cl_apis = []
    for pack in cl_wsdl:
        if not pack:
            continue
        module_name = '%s.cl_wsdl_%s' % (pack.replace("-", "."),
                                         pack.rpartition("-")[2])
        cl_wsdl_core = importlib.import_module(module_name)
        try:
            cl_apis.append(cl_wsdl_core.Wsdl)
        except (AttributeError, ImportError):
            sys.stderr.write(_("Unable to import %s") % module_name + "\n")
    return cl_apis

# Create an RSA key and a certificate signing request (server.csr)
def new_key_req(key, cert_path, serv_host_name, port):
    """
    Generate an RSA key and write a certificate signing request.

    Key material is saved to ``key + '_pub'`` (via rsa.save_key) and to
    ``key`` (via pkey.save_key), both with cipher=None. The request built by
    makeRequest for serv_host_name/port is written as PEM to
    <cert_path>/server.csr. The process exits via sys.exit() if no request
    could be built.
    """
    from create_cert import generateRSAKey, makePKey, makeRequest, \
        passphrase_callback
    rsa = generateRSAKey()
    rsa.save_key(key + '_pub', cipher=None, callback=passphrase_callback)

    pkey = makePKey(rsa)
    pkey.save_key(key, cipher=None, callback=passphrase_callback)

    req = makeRequest(rsa, pkey, serv_host_name, port)
    if not req:
        sys.exit()
    crtreq = req.as_pem()
    # `with` closes the file even if the write fails
    with open(cert_path + '/server.csr', 'w') as crtfile:
        crtfile.write(crtreq)

# remove duplicates from a sequence, keeping first-seen order
def uniq(seq):
    """Return a list of the items of *seq* without duplicates.

    The first occurrence of each item is kept and the original order is
    preserved; items must be hashable.
    """
    already = set()
    unique_items = []
    for element in seq:
        if element in already:
            continue
        already.add(element)
        unique_items.append(element)
    return unique_items
