]> git.ozlabs.org Git - patchwork/blobdiff - apps/patchwork/models.py
Simplify hashlib behaviour in HashField
[patchwork] / apps / patchwork / models.py
index fb2ccc7a298aaba21496dcf644dc916b4eacb81e..62ce59266b8a445a5c3a1516dd1fd4d83335b451 100644 (file)
@@ -22,6 +22,7 @@ from django.contrib.auth.models import User
 from django.core.urlresolvers import reverse
 from django.contrib.sites.models import Site
 from django.conf import settings
+from patchwork.parser import hash_patch
 import django.oldforms as oldforms
 
 import re
@@ -168,7 +169,7 @@ class State(models.Model):
     class Meta:
         ordering = ['ordering']
 
-class HashField(models.Field):
+class HashField(models.CharField):
     __metaclass__ = models.SubfieldBase
 
     def __init__(self, algorithm = 'sha1', *args, **kwargs):
@@ -176,40 +177,24 @@ class HashField(models.Field):
         try:
             import hashlib
             self.hashlib = True
+            self.n_bytes = len(hashlib.new(self.algorithm).hexdigest())
         except ImportError:
             self.hashlib = False
             if algorithm == 'sha1':
                 import sha
-                self.hash_constructor = sha.new
+                hash_constructor = sha.new
             elif algorithm == 'md5':
                 import md5
-                self.hash_constructor = md5.new
+                hash_constructor = md5.new
             else:
                 raise NameError("Unknown algorithm '%s'" % algorithm)
-            
+            self.n_bytes = len(hash_constructor().hexdigest())
+
+        kwargs['max_length'] = self.n_bytes
         super(HashField, self).__init__(*args, **kwargs)
 
     def db_type(self):
-        if self.hashlib:
-            n_bytes = len(hashlib.new(self.algorithm).digest())
-        else:
-            n_bytes = len(self.hash_constructor().digest())
-        if settings.DATABASE_ENGINE.startswith('postgresql'):
-            return 'bytea'
-        elif settings.DATABASE_ENGINE == 'mysql':
-            return 'binary(%d)' % n_bytes
-        else:
-            raise Exception("Unknown database engine '%s'" % \
-                            settings.DATABASE_ENGINE)
-
-    def to_python(self, value):
-        return value
-
-    def get_db_prep_save(self, value):
-        return ''.join(map(lambda x: '\\%03o' % ord(x), value))
-
-    def get_manipulator_field_objs(self):
-        return [oldforms.TextField]
+        return 'char(%d)' % self.n_bytes
 
 class Patch(models.Model):
     project = models.ForeignKey(Project)
@@ -223,7 +208,7 @@ class Patch(models.Model):
     headers = models.TextField(blank = True)
     content = models.TextField()
     commit_ref = models.CharField(max_length=255, null = True, blank = True)
-    hash = HashField()
+    hash = HashField(null = True, db_index = True)
 
     def __str__(self):
         return self.name
@@ -236,8 +221,10 @@ class Patch(models.Model):
             s = self.state
         except:
             self.state = State.objects.get(ordering =  0)
-        if hash is None:
-            print "no hash"
+
+        if self.hash is None:
+            self.hash = hash_patch(self.content).hexdigest()
+
         super(Patch, self).save()
 
     def is_editable(self, user):