diff --git a/rhodecode/lib/dbmigrate/migrate/versioning/template.py b/rhodecode/lib/dbmigrate/migrate/versioning/template.py new file mode 100644 --- /dev/null +++ b/rhodecode/lib/dbmigrate/migrate/versioning/template.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import shutil +import sys + +from pkg_resources import resource_filename + +from migrate.versioning.config import * +from migrate.versioning import pathed + + +class Collection(pathed.Pathed): + """A collection of templates of a specific type""" + _mask = None + + def get_path(self, file): + return os.path.join(self.path, str(file)) + + +class RepositoryCollection(Collection): + _mask = '%s' + +class ScriptCollection(Collection): + _mask = '%s.py_tmpl' + +class ManageCollection(Collection): + _mask = '%s.py_tmpl' + +class SQLScriptCollection(Collection): + _mask = '%s.py_tmpl' + +class Template(pathed.Pathed): + """Finds the paths/packages of various Migrate templates. + + :param path: Templates are loaded from migrate package + if `path` is not provided. + """ + pkg = 'migrate.versioning.templates' + _manage = 'manage.py_tmpl' + + def __new__(cls, path=None): + if path is None: + path = cls._find_path(cls.pkg) + return super(Template, cls).__new__(cls, path) + + def __init__(self, path=None): + if path is None: + path = Template._find_path(self.pkg) + super(Template, self).__init__(path) + self.repository = RepositoryCollection(os.path.join(path, 'repository')) + self.script = ScriptCollection(os.path.join(path, 'script')) + self.manage = ManageCollection(os.path.join(path, 'manage')) + self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script')) + + @classmethod + def _find_path(cls, pkg): + """Returns absolute path to dotted python package.""" + tmp_pkg = pkg.rsplit('.', 1) + + if len(tmp_pkg) != 1: + return resource_filename(tmp_pkg[0], tmp_pkg[1]) + else: + return resource_filename(tmp_pkg[0], '') + + def _get_item(self, collection, theme=None): + """Locates and returns collection. + + :param collection: name of collection to locate + :param type_: type of subfolder in collection (defaults to "_default") + :returns: (package, source) + :rtype: str, str + """ + item = getattr(self, collection) + theme_mask = getattr(item, '_mask') + theme = theme_mask % (theme or 'default') + return item.get_path(theme) + + def get_repository(self, *a, **kw): + """Calls self._get_item('repository', *a, **kw)""" + return self._get_item('repository', *a, **kw) + + def get_script(self, *a, **kw): + """Calls self._get_item('script', *a, **kw)""" + return self._get_item('script', *a, **kw) + + def get_sql_script(self, *a, **kw): + """Calls self._get_item('sql_script', *a, **kw)""" + return self._get_item('sql_script', *a, **kw) + + def get_manage(self, *a, **kw): + """Calls self._get_item('manage', *a, **kw)""" + return self._get_item('manage', *a, **kw)