diff --git a/backend/core/admin.py b/backend/core/admin.py index 9c0f3bfa..20f4947b 100644 --- a/backend/core/admin.py +++ b/backend/core/admin.py @@ -51,8 +51,8 @@ class FeedbackAdmin(geoadmin.OSMGeoAdmin): @admin.register(Banner) class BannerAdmin(admin.ModelAdmin): - list_display = ("message", "start_date", "end_date", "is_active", "is_displayable") - list_filter = ("is_active", "start_date", "end_date") + list_display = ("message", "start_date", "end_date", "is_displayable") + list_filter = ("start_date", "end_date") search_fields = ("message",) readonly_fields = ("is_displayable",) diff --git a/backend/core/models.py b/backend/core/models.py index 20ac0776..3964587d 100644 --- a/backend/core/models.py +++ b/backend/core/models.py @@ -47,7 +47,7 @@ class Label(models.Model): class Model(models.Model): - FOUNDATION_MODEL_CHOICES = ( + BASE_MODEL_CHOICES = ( ("RAMP", "RAMP"), ("YOLO", "YOLO"), ) @@ -65,8 +65,8 @@ class ModelStatus(models.IntegerChoices): user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) published_training = models.PositiveIntegerField(null=True, blank=True) status = models.IntegerField(default=-1, choices=ModelStatus.choices) - foundation_model = models.CharField( - choices=FOUNDATION_MODEL_CHOICES, default="RAMP", max_length=10 + base_model = models.CharField( + choices=BASE_MODEL_CHOICES, default="RAMP", max_length=10 ) @@ -159,17 +159,14 @@ class ApprovedPredictions(models.Model): class Banner(models.Model): - message = models.CharField(max_length=255) + message = models.TextField() start_date = models.DateTimeField(default=timezone.now) end_date = models.DateTimeField(null=True, blank=True) - is_active = models.BooleanField(default=True) def is_displayable(self): now = timezone.now() - return ( - self.is_active - and (self.start_date <= now) - and (self.end_date is None or self.end_date >= now) + return (self.start_date <= now) and ( + self.end_date is None or self.end_date >= now ) def __str__(self): diff --git a/backend/core/serializers.py b/backend/core/serializers.py index ec2a22d8..bffdc1d9 100644 --- a/backend/core/serializers.py +++ b/backend/core/serializers.py @@ -403,7 +403,6 @@ class Meta: "message", "start_date", "end_date", - "is_active", "is_displayable", ] read_only_fields = ["is_displayable"] diff --git a/backend/core/views.py b/backend/core/views.py index 8722ca3f..6ce1960f 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -724,16 +724,24 @@ def post(self, request, *args, **kwargs): def publish_training(request, training_id: int): """Publishes training for model""" training_instance = get_object_or_404(Training, id=training_id) + if training_instance.status != "FINISHED": return Response("Training is not FINISHED", status=404) if training_instance.accuracy < 70: return Response( - "Can't publish the training since it's accuracy is below 70 %", status=404 + "Can't publish the training since its accuracy is below 70%", status=404 ) + model_instance = get_object_or_404(Model, id=training_instance.model.id) + + # Check if the current user is the owner of the model + if model_instance.user != request.user: + return Response("You are not allowed to publish this training", status=403) + model_instance.published_training = training_instance.id model_instance.status = 0 model_instance.save() + return Response("Training Published", status=status.HTTP_201_CREATED) @@ -800,8 +808,8 @@ def get(self, request, lookup_dir=None): class TrainingWorkspaceDownloadView(APIView): - # authentication_classes = [OsmAuthentication] - # permission_classes = [IsOsmAuthenticated] + authentication_classes = [OsmAuthentication] + permission_classes = [IsOsmAuthenticated] def get(self, request, lookup_dir): base_dir = os.path.join(settings.TRAINING_WORKSPACE, lookup_dir) @@ -847,7 +855,8 @@ class BannerViewSet(viewsets.ModelViewSet): queryset = Banner.objects.all() serializer_class = BannerSerializer authentication_classes = [OsmAuthentication] - permission_classes = [IsStaffUser, IsAdminUser] + permission_classes = [IsAdminUser, IsStaffUser] + public_methods = ["GET"] def get_queryset(self): now = timezone.now() diff --git a/backend/login/admin.py b/backend/login/admin.py index 72147c90..566991b1 100644 --- a/backend/login/admin.py +++ b/backend/login/admin.py @@ -86,6 +86,8 @@ class OsmUserAdmin(admin.ModelAdmin): }, ), ) - formfield_overrides = { - models.CharField: {"validators": []}, - } + + def formfield_for_dbfield(self, db_field, request, **kwargs): + if db_field.name == "username": + kwargs["validators"] = [] ## override the validation for sername + return super().formfield_for_dbfield(db_field, request, **kwargs) diff --git a/backend/login/permissions.py b/backend/login/permissions.py index bc76ebad..1ff82374 100644 --- a/backend/login/permissions.py +++ b/backend/login/permissions.py @@ -10,9 +10,9 @@ def has_permission(self, request, view): public_methods = getattr(view, "public_methods", []) if request.method in public_methods: return True - # If the user is authenticated, allow access + if request.user and request.user.is_authenticated: - # Global access for staff and admin users + # Global access if request.user.is_staff or request.user.is_superuser: return True @@ -21,7 +21,7 @@ def has_permission(self, request, view): return False def has_object_permission(self, request, view, obj): - # Allow read-only access for any authenticated user + if request.method in permissions.SAFE_METHODS: return True @@ -29,8 +29,7 @@ def has_object_permission(self, request, view, obj): if request.user.is_staff or request.user.is_superuser: return True - # Check if the object has a 'creator' field and if the user is the creator - if hasattr(obj, "creator") and obj.creator == request.user: + if hasattr(obj, "user") and obj.user == request.user: return True return False