diff --git a/rhodecode/model/hg_model.py b/rhodecode/model/scm.py
copy from rhodecode/model/hg_model.py
copy to rhodecode/model/scm.py
--- a/rhodecode/model/hg_model.py
+++ b/rhodecode/model/scm.py
@@ -1,8 +1,15 @@
-#!/usr/bin/env python
-# encoding: utf-8
-# Model for RhodeCode
-# Copyright (C) 2009-2010 Marcin Kuzminski
-#
+# -*- coding: utf-8 -*-
+"""
+    rhodecode.model.scm
+    ~~~~~~~~~~~~~~~~~~~
+
+    Scm model for RhodeCode
+
+    :created_on: Apr 9, 2010
+    :author: marcink
+    :copyright: (C) 2009-2010 Marcin Kuzminski
+    :license: GPLv3, see COPYING for more details.
+"""
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
@@ -17,170 +24,354 @@
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA 02110-1301, USA.
-"""
-Created on April 9, 2010
-Model for RhodeCode
-@author: marcink
-"""
-from beaker.cache import cache_region
+import os
+import time
+import traceback
+import logging
+
+from vcs import get_backend
+from vcs.utils.helpers import get_scm
+from vcs.exceptions import RepositoryError, VCSError
+from vcs.utils.lazy import LazyProperty
+
 from mercurial import ui
-from mercurial.hgweb.hgwebdir_mod import findrepos
-from pylons.i18n.translation import _
+
+from beaker.cache import cache_region, region_invalidate
+
+from rhodecode import BACKENDS
 from rhodecode.lib import helpers as h
-from rhodecode.lib.utils import invalidate_cache
 from rhodecode.lib.auth import HasRepoPermissionAny
-from rhodecode.model import meta
-from rhodecode.model.db import Repository, User
+from rhodecode.lib.utils import get_repos, make_ui, action_logger
+from rhodecode.model import BaseModel
+from rhodecode.model.user import UserModel
+
+from rhodecode.model.db import Repository, RhodeCodeUi, CacheInvalidation, \
+    UserFollowing, UserLog
+from rhodecode.model.caching_query import FromCache
+
 from sqlalchemy.orm import joinedload
-from vcs.exceptions import RepositoryError, VCSError
-import logging
-import os
-import sys
+from sqlalchemy.orm.session import make_transient
+from sqlalchemy.exc import DatabaseError
+
 log = logging.getLogger(__name__)
 
-try:
-    from vcs.backends.hg import MercurialRepository
-except ImportError:
-    sys.stderr.write('You have to import vcs module')
-    raise Exception('Unable to import vcs')
-
-def _get_repos_cached_initial(app_globals, initial):
-    """return cached dict with repos
-    """
-    g = app_globals
-    return HgModel.repo_scan(g.paths[0][0], g.paths[0][1], g.baseui, initial)
-@cache_region('long_term', 'cached_repo_list')
-def _get_repos_cached():
-    """return cached dict with repos
-    """
-    log.info('getting all repositories list')
-    from pylons import app_globals as g
-    return HgModel.repo_scan(g.paths[0][0], g.paths[0][1], g.baseui)
+class UserTemp(object):
+    def __init__(self, user_id):
+        self.user_id = user_id
+class RepoTemp(object):
+    def __init__(self, repo_id):
+        self.repo_id = repo_id
-@cache_region('super_short_term', 'cached_repos_switcher_list')
-def _get_repos_switcher_cached(cached_repo_list):
-    repos_lst = []
-    for repo in [x for x in cached_repo_list.values()]:
-        if HasRepoPermissionAny('repository.write', 'repository.read',
-                                'repository.admin')(repo.name, 'main page check'):
-            repos_lst.append((repo.name, repo.dbrepo.private,))
-
-    return sorted(repos_lst, key=lambda k:k[0].lower())
-
-@cache_region('long_term', 'full_changelog')
-def _full_changelog_cached(repo_name):
-    log.info('getting full changelog for %s', repo_name)
-    return list(reversed(list(HgModel().get_repo(repo_name))))
-
-class HgModel(object):
-    """Mercurial Model
+class ScmModel(BaseModel):
+    """Generic Scm Model
     """
-    def __init__(self):
-        pass
-
-    @staticmethod
-    def repo_scan(repos_prefix, repos_path, baseui, initial=False):
-        """
-        Listing of repositories in given path. This path should not be a
-        repository itself. Return a dictionary of repository objects
-        :param repos_path: path to directory it could take syntax with
-        * or ** for deep recursive displaying repositories
+    @LazyProperty
+    def repos_path(self):
+        """Gets the repositories root path from the database
         """
-        sa = meta.Session()
-        def check_repo_dir(path):
-            """Checks the repository
-            :param path:
-            """
-            repos_path = path.split('/')
-            if repos_path[-1] in ['*', '**']:
-                repos_path = repos_path[:-1]
-            if repos_path[0] != '/':
-                repos_path[0] = '/'
-            if not os.path.isdir(os.path.join(*repos_path)):
-                raise RepositoryError('Not a valid repository in %s' % path)
-        if not repos_path.endswith('*'):
-            raise VCSError('You need to specify * or ** at the end of path '
-                           'for recursive scanning')
-
-        check_repo_dir(repos_path)
+
+        q = self.sa.query(RhodeCodeUi).filter(RhodeCodeUi.ui_key == '/').one()
+
+        return q.ui_value
+
+    def repo_scan(self, repos_path, baseui):
+        """Listing of repositories in given path. This path should not be a
+        repository itself. Return a dictionary of repository objects
+
+        :param repos_path: path to directory containing repositories
+        :param baseui: baseui instance to instantiate MercurialRepository with
+        """
+        log.info('scanning for repositories in %s', repos_path)
-        repos = findrepos([(repos_prefix, repos_path)])
+
         if not isinstance(baseui, ui.ui):
-            baseui = ui.ui()
-
+            baseui = make_ui('db')
         repos_list = {}
-        for name, path in repos:
+
+        for name, path in get_repos(repos_path):
             try:
-                #name = name.split('/')[-1]
                 if repos_list.has_key(name):
-                    raise RepositoryError('Duplicate repository name %s found in'
-                                          ' %s' % (name, path))
+                    raise RepositoryError('Duplicate repository name %s '
+                                          'found in %s' % (name, path))
                 else:
-
-                    repos_list[name] = MercurialRepository(path, baseui=baseui)
-                    repos_list[name].name = name
-
-                    dbrepo = None
-                    if not initial:
-                        #for initial scann on application first run we don't
-                        #have db repos yet.
-                        dbrepo = sa.query(Repository)\
-                            .options(joinedload(Repository.fork))\
-                            .filter(Repository.repo_name == name)\
-                            .scalar()
-
-                    if dbrepo:
-                        log.info('Adding db instance to cached list')
-                        repos_list[name].dbrepo = dbrepo
-                        repos_list[name].description = dbrepo.description
-                        if dbrepo.user:
-                            repos_list[name].contact = dbrepo.user.full_contact
-                        else:
-                            repos_list[name].contact = sa.query(User)\
-                                .filter(User.admin == True).first().full_contact
+
+                    klass = get_backend(path[0])
+
+                    if path[0] == 'hg' and path[0] in BACKENDS.keys():
+                        repos_list[name] = klass(path[1], baseui=baseui)
+
+                    if path[0] == 'git' and path[0] in BACKENDS.keys():
+                        repos_list[name] = klass(path[1])
             except OSError:
                 continue
-        meta.Session.remove()
         return repos_list
+
+    def get_repos(self, all_repos=None):
+        """Get all repos from db and for each repo create its backend instance
+        and fill that backend with information from the database
-    def get_repos(self):
-        for name, repo in _get_repos_cached().items():
-            if repo._get_hidden():
-                #skip hidden web repository
-                continue
-
-            last_change = repo.last_change
-            tip = h.get_changeset_safe(repo, 'tip')
-
-            tmp_d = {}
-            tmp_d['name'] = repo.name
-            tmp_d['name_sort'] = tmp_d['name'].lower()
-            tmp_d['description'] = repo.description
-            tmp_d['description_sort'] = tmp_d['description']
-            tmp_d['last_change'] = last_change
-            tmp_d['last_change_sort'] = last_change[1] - last_change[0]
-            tmp_d['tip'] = tip.short_id
-            tmp_d['tip_sort'] = tip.revision
-            tmp_d['rev'] = tip.revision
-            tmp_d['contact'] = repo.contact
-            tmp_d['contact_sort'] = tmp_d['contact']
-            tmp_d['repo_archives'] = list(repo._get_archives())
-            tmp_d['last_msg'] = tip.message
-            tmp_d['repo'] = repo
-            yield tmp_d
+        :param all_repos: pass a specific list of repositories, useful for filtering
+        """
+
+        if all_repos is None:
+            all_repos = self.sa.query(Repository)\
+                .order_by(Repository.repo_name).all()
+
+        #get the repositories that should be invalidated
+        invalidation_list = [str(x.cache_key) for x in \
+            self.sa.query(CacheInvalidation.cache_key)\
+            .filter(CacheInvalidation.cache_active == False)\
+            .all()]
+
+        for r in all_repos:
+
+            repo = self.get(r.repo_name, invalidation_list)
+
+            if repo is not None:
+                last_change = repo.last_change
+                tip = h.get_changeset_safe(repo, 'tip')
+
+                tmp_d = {}
+                tmp_d['name'] = repo.name
+                tmp_d['name_sort'] = tmp_d['name'].lower()
+                tmp_d['description'] = repo.dbrepo.description
+                tmp_d['description_sort'] = tmp_d['description']
+                tmp_d['last_change'] = last_change
+                tmp_d['last_change_sort'] = time.mktime(last_change.timetuple())
+                tmp_d['tip'] = tip.raw_id
+                tmp_d['tip_sort'] = tip.revision
+                tmp_d['rev'] = tip.revision
+                tmp_d['contact'] = repo.dbrepo.user.full_contact
+                tmp_d['contact_sort'] = tmp_d['contact']
+                tmp_d['repo_archives'] = list(repo._get_archives())
+                tmp_d['last_msg'] = tip.message
+                tmp_d['repo'] = repo
+                yield tmp_d
 
     def get_repo(self, repo_name):
-        try:
-            repo = _get_repos_cached()[repo_name]
+        return self.get(repo_name)
+
+    def get(self, repo_name, invalidation_list=None):
+        """Gets a repository by name, creates a BackendInstance and
+        propagates its data from the database with all additional information
+
+        :param repo_name:
+        :param invalidation_list: if an invalidation list is given, the get
+            method should not manually check whether this repository needs
+            invalidation and should just invalidate the repositories in the list
+
+        """
+        if not HasRepoPermissionAny('repository.read', 'repository.write',
+                                    'repository.admin')(repo_name, 'get repo check'):
+            return
+
+        #======================================================================
+        # CACHE FUNCTION
+        #======================================================================
+        @cache_region('long_term')
+        def _get_repo(repo_name):
+
+            repo_path = os.path.join(self.repos_path, repo_name)
+
+            try:
+                alias = get_scm(repo_path)[0]
+
+                log.debug('Creating instance of %s repository', alias)
+                backend = get_backend(alias)
+            except VCSError:
+                log.error(traceback.format_exc())
+                return
+
+            if alias == 'hg':
+                from pylons import app_globals as g
+                repo = backend(repo_path, create=False, baseui=g.baseui)
+                #skip hidden web repository
+                if repo._get_hidden():
+                    return
+            else:
+                repo = backend(repo_path, create=False)
+
+            dbrepo = self.sa.query(Repository)\
+                .options(joinedload(Repository.fork))\
+                .options(joinedload(Repository.user))\
+                .filter(Repository.repo_name == repo_name)\
+                .scalar()
+
+            make_transient(dbrepo)
+            if dbrepo.user:
+                make_transient(dbrepo.user)
+            if dbrepo.fork:
+                make_transient(dbrepo.fork)
+
+            repo.dbrepo = dbrepo
             return repo
-        except KeyError:
-            #i we're here and we got key errors let's try to invalidate the
-            #cahce and try again
-            invalidate_cache('cached_repo_list')
-            repo = _get_repos_cached()[repo_name]
-            return repo
-
+
+        pre_invalidate = True
+        if invalidation_list is not None:
+            pre_invalidate = repo_name in invalidation_list
+
+        if pre_invalidate:
+            invalidate = self._should_invalidate(repo_name)
+
+            if invalidate:
+                log.info('invalidating cache for repository %s', repo_name)
+                region_invalidate(_get_repo, None, repo_name)
+                self._mark_invalidated(invalidate)
+
+        return _get_repo(repo_name)
+
+
+
+    def mark_for_invalidation(self, repo_name):
+        """Puts a cache invalidation task into the db for
+        further global cache invalidation
+        :param repo_name: the repo for which invalidation should take place
+        """
+
+        log.debug('marking %s for invalidation', repo_name)
+        cache = self.sa.query(CacheInvalidation)\
+            .filter(CacheInvalidation.cache_key == repo_name).scalar()
+
+        if cache:
+            #mark this cache as inactive
+            cache.cache_active = False
+        else:
+            log.debug('cache key not found in invalidation db -> creating one')
+            cache = CacheInvalidation(repo_name)
+
+        try:
+            self.sa.add(cache)
+            self.sa.commit()
+        except (DatabaseError,):
+            log.error(traceback.format_exc())
+            self.sa.rollback()
+
+
+    def toggle_following_repo(self, follow_repo_id, user_id):
+
+        f = self.sa.query(UserFollowing)\
+            .filter(UserFollowing.follows_repo_id == follow_repo_id)\
+            .filter(UserFollowing.user_id == user_id).scalar()
+
+        if f is not None:
+
+            try:
+                self.sa.delete(f)
+                self.sa.commit()
+                action_logger(UserTemp(user_id),
+                              'stopped_following_repo',
+                              RepoTemp(follow_repo_id))
+                return
+            except:
+                log.error(traceback.format_exc())
+                self.sa.rollback()
+                raise
+
+
+        try:
+            f = UserFollowing()
+            f.user_id = user_id
+            f.follows_repo_id = follow_repo_id
+            self.sa.add(f)
+            self.sa.commit()
+            action_logger(UserTemp(user_id),
+                          'started_following_repo',
+                          RepoTemp(follow_repo_id))
+        except:
+            log.error(traceback.format_exc())
+            self.sa.rollback()
+            raise
+
+    def toggle_following_user(self, follow_user_id, user_id):
+        f = self.sa.query(UserFollowing)\
+            .filter(UserFollowing.follows_user_id == follow_user_id)\
+            .filter(UserFollowing.user_id == user_id).scalar()
+
+        if f is not None:
+            try:
+                self.sa.delete(f)
+                self.sa.commit()
+                return
+            except:
+                log.error(traceback.format_exc())
+                self.sa.rollback()
+                raise
+
+        try:
+            f = UserFollowing()
+            f.user_id = user_id
+            f.follows_user_id = follow_user_id
+            self.sa.add(f)
+            self.sa.commit()
+        except:
+            log.error(traceback.format_exc())
+            self.sa.rollback()
+            raise
+
+    def is_following_repo(self, repo_name, user_id):
+        r = self.sa.query(Repository)\
+            .filter(Repository.repo_name == repo_name).scalar()
+
+        f = self.sa.query(UserFollowing)\
+            .filter(UserFollowing.follows_repository == r)\
+            .filter(UserFollowing.user_id == user_id).scalar()
+
+        return f is not None
+
+    def is_following_user(self, username, user_id):
+        u = UserModel(self.sa).get_by_username(username)
+
+        f = self.sa.query(UserFollowing)\
+            .filter(UserFollowing.follows_user == u)\
+            .filter(UserFollowing.user_id == user_id).scalar()
+
+        return f is not None
+
+    def get_followers(self, repo_id):
+        return self.sa.query(UserFollowing)\
+            .filter(UserFollowing.follows_repo_id == repo_id).count()
+
+    def get_forks(self, repo_id):
+        return self.sa.query(Repository)\
+            .filter(Repository.fork_id == repo_id).count()
+
+
+    def get_unread_journal(self):
+        return self.sa.query(UserLog).count()
+
+
+    def _should_invalidate(self, repo_name):
+        """Looks up the database for invalidation signals for this repo_name
+        :param repo_name:
+        """
+
+        ret = self.sa.query(CacheInvalidation)\
+            .options(FromCache('sql_cache_short',
+                               'get_invalidation_%s' % repo_name))\
+            .filter(CacheInvalidation.cache_key == repo_name)\
+            .filter(CacheInvalidation.cache_active == False)\
+            .scalar()
+
+        return ret
+
+    def _mark_invalidated(self, cache_key):
+        """Marks all occurrences of cache invalidation as already invalidated
+
+        :param cache_key:
+        """
+
+        if cache_key:
+            log.debug('marking %s as already invalidated', cache_key)
+            try:
+                cache_key.cache_active = True
+                self.sa.add(cache_key)
+                self.sa.commit()
+            except (DatabaseError,):
+                log.error(traceback.format_exc())
+                self.sa.rollback()
+
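
Note, outside the patch itself: the invalidation cycle spread across ScmModel.get(), mark_for_invalidation(), _should_invalidate() and _mark_invalidated() can be hard to follow from the diff alone. The sketch below is a minimal, self-contained illustration of that flow under stated assumptions: plain dicts stand in for the beaker 'long_term' region and for the CacheInvalidation table, and every name in it (_cache, _invalidation, get, mark_for_invalidation) is a hypothetical stand-in rather than a RhodeCode API.

# Minimal sketch (assumption-laden) of the cache-invalidation cycle used by
# ScmModel.get().  Plain dicts replace the beaker 'long_term' region and the
# CacheInvalidation table; names are illustrative only.

_cache = {}          # beaker region stand-in: repo_name -> cached repo object
_invalidation = {}   # CacheInvalidation stand-in: repo_name -> cache_active flag


def mark_for_invalidation(repo_name):
    """Roughly ScmModel.mark_for_invalidation(): flag the key as stale (inactive)."""
    _invalidation[repo_name] = False


def _should_invalidate(repo_name):
    """Roughly ScmModel._should_invalidate(): True when an inactive row exists."""
    return _invalidation.get(repo_name) is False


def _mark_invalidated(repo_name):
    """Roughly ScmModel._mark_invalidated(): flip the row back to active."""
    _invalidation[repo_name] = True


def get(repo_name):
    """Simplified ScmModel.get(): rebuild the cached value only when stale."""
    if _should_invalidate(repo_name):
        _cache.pop(repo_name, None)         # plays the role of region_invalidate()
        _mark_invalidated(repo_name)
    if repo_name not in _cache:             # plays the role of @cache_region
        _cache[repo_name] = 'backend instance for %s' % repo_name
    return _cache[repo_name]


if __name__ == '__main__':
    print(get('vcs'))               # first call populates the cache
    mark_for_invalidation('vcs')    # e.g. triggered after a push
    print(get('vcs'))               # stale entry is dropped and rebuilt

The apparent design intent is that the "stale" flag lives in the database so that every worker process sees the same signal, while each process only drops its own in-memory beaker region when it next serves the repository.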