diff --git a/backend/aiproject/settings.py b/backend/aiproject/settings.py index 2c4cf2b7..106104eb 100644 --- a/backend/aiproject/settings.py +++ b/backend/aiproject/settings.py @@ -10,12 +10,13 @@ https://docs.aiproject.com/en/3.1/ref/settings/ """ -import os import logging +import os +from socket import gethostbyname, gethostname + import dj_database_url import environ from corsheaders.defaults import default_headers -from socket import gethostbyname, gethostname env = environ.Env() @@ -102,6 +103,7 @@ CORS_ORIGIN_WHITELIST = ALLOWED_ORIGINS CORS_ORIGIN_ALLOW_ALL = env("CORS_ORIGIN_ALLOW_ALL", default=False) +DEFAULT_PAGINATION_SIZE = env("DEFAULT_PAGINATION_SIZE", default=50) REST_FRAMEWORK = { "DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.coreapi.AutoSchema", @@ -110,6 +112,8 @@ "rest_framework.authentication.BasicAuthentication", "login.authentication.OsmAuthentication", ], + "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.LimitOffsetPagination", + "PAGE_SIZE": DEFAULT_PAGINATION_SIZE, } ROOT_URLCONF = "aiproject.urls" @@ -208,7 +212,7 @@ } } # get ramp home and set it to environ -RAMP_HOME = env("RAMP_HOME",default=None) +RAMP_HOME = env("RAMP_HOME", default=None) if RAMP_HOME: os.environ["RAMP_HOME"] = RAMP_HOME @@ -220,4 +224,4 @@ ENABLE_PREDICTION_API = env("ENABLE_PREDICTION_API", default=False) -TEST_RUNNER = 'tests.test_runners.NoDestroyTestRunner' +TEST_RUNNER = "tests.test_runners.NoDestroyTestRunner" diff --git a/backend/core/models.py b/backend/core/models.py index 7ad9afd6..60247ff9 100644 --- a/backend/core/models.py +++ b/backend/core/models.py @@ -56,6 +56,7 @@ class ModelStatus(models.IntegerChoices): name = models.CharField(max_length=255) created_at = models.DateTimeField(auto_now_add=True) last_modified = models.DateTimeField(auto_now=True) + description = models.TextField(max_length=500, null=True, blank=True) created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) published_training = models.PositiveIntegerField(null=True, blank=True) status = models.IntegerField(default=-1, choices=ModelStatus.choices) # diff --git a/backend/core/serializers.py b/backend/core/serializers.py index 7465bcf2..3f4cc556 100644 --- a/backend/core/serializers.py +++ b/backend/core/serializers.py @@ -29,12 +29,31 @@ def create(self, validated_data): return super().create(validated_data) +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = OsmUser + fields = [ + "osm_id", + "username", + # "is_superuser", + # "is_active", + # "is_staff", + "date_joined", + # "email", + "img_url", + # "user_permissions", + ] + + class ModelSerializer( serializers.ModelSerializer ): # serializers are used to translate models objects to api + created_by = UserSerializer(read_only=True) + accuracy = serializers.SerializerMethodField() + class Meta: model = Model - fields = "__all__" # defining all the fields to be included in curd for now , we can restrict few if we want + fields = "__all__" read_only_fields = ( "created_at", "last_modified", @@ -47,6 +66,44 @@ def create(self, validated_data): validated_data["created_by"] = user return super().create(validated_data) + def get_accuracy( + self, obj + ): ## this might have performance problem when db grows bigger , consider adding indexes / view in db + training = Training.objects.filter(id=obj.published_training).first() + if training: + return training.accuracy + return None + + +class ModelCentroidSerializer(GeoFeatureModelSerializer): + geometry = serializers.SerializerMethodField() + mid = serializers.IntegerField(source="id") + + class Meta: + model = Model + geo_field = "geometry" + fields = ("mid", "name", "geometry") + + def get_geometry(self, obj): + """ + Get the centroid of the AOI linked to the dataset of the given model. + """ + aoi = AOI.objects.filter(dataset=obj.dataset).first() + if aoi and aoi.geom: + return { + "type": "Point", + "coordinates": aoi.geom.centroid.coords, + } + return None + + # def to_representation(self, instance): + # """ + # Override to_representation to customize GeoJSON structure. + # """ + # representation = super().to_representation(instance) + # representation["properties"]["id"] = representation.pop("id") + # return representation + class AOISerializer( GeoFeatureModelSerializer @@ -314,19 +371,3 @@ def validate(self, data): data["area_threshold"] ) return data - - -class UserSerializer(serializers.ModelSerializer): - class Meta: - model = OsmUser - fields = [ - "osm_id", - "username", - "is_superuser", - "is_active", - "is_staff", - "date_joined", - "email", - "img_url", - "user_permissions", - ] diff --git a/backend/core/urls.py b/backend/core/urls.py index 8212eade..b54bc169 100644 --- a/backend/core/urls.py +++ b/backend/core/urls.py @@ -16,12 +16,14 @@ GenerateFeedbackAOIGpxView, GenerateGpxView, LabelViewSet, + ModelCentroidView, ModelViewSet, RawdataApiAOIView, RawdataApiFeedbackView, TrainingViewSet, TrainingWorkspaceDownloadView, TrainingWorkspaceView, + UsersView, download_training_data, geojson2osmconverter, publish_training, @@ -51,6 +53,8 @@ "label/feedback/osm/fetch//", RawdataApiFeedbackView.as_view(), ), + path("users/", UsersView.as_view(), name="user-list-view"), + path("models/centroid/", ModelCentroidView.as_view(), name="model-centroid"), # path("download//", download_training_data), path("training/status//", run_task_status), path("training/publish//", publish_training), diff --git a/backend/core/views.py b/backend/core/views.py index d0bd9634..c89e5974 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -28,9 +28,10 @@ from geojson2osm import geojson2osm from orthogonalizer import othogonalize_poly from osmconflator import conflate_geojson -from rest_framework import decorators, serializers, status, viewsets +from rest_framework import decorators, filters, serializers, status, viewsets from rest_framework.decorators import api_view from rest_framework.exceptions import ValidationError +from rest_framework.generics import ListAPIView from rest_framework.response import Response from rest_framework.views import APIView from rest_framework_gis.filters import InBBoxFilter, TMSTileFilter @@ -47,6 +48,7 @@ FeedbackLabel, Label, Model, + OsmUser, Training, ) from .serializers import ( @@ -59,8 +61,10 @@ FeedbackParamSerializer, FeedbackSerializer, LabelSerializer, + ModelCentroidSerializer, ModelSerializer, PredictionParamSerializer, + UserSerializer, ) from .tasks import train_model from .utils import get_dir_size, gpx_generator, process_rawdata, request_rawdata @@ -236,8 +240,49 @@ class ModelViewSet( permission_classes = [IsOsmAuthenticated] permission_allowed_methods = ["GET"] queryset = Model.objects.all() - serializer_class = ModelSerializer # connecting serializer - filterset_fields = ["status"] + filter_backends = ( + InBBoxFilter, # it will take bbox like this api/v1/model/?in_bbox=-90,29,-89,35 , + DjangoFilterBackend, + filters.SearchFilter, + filters.OrderingFilter, + ) + serializer_class = ModelSerializer + filterset_fields = { + "status": ["exact"], + "created_at": ["exact", "gt", "gte", "lt", "lte"], + "last_modified": ["exact", "gt", "gte", "lt", "lte"], + "created_by": ["exact"], + "id": ["exact"], + } + ordering_fields = ["created_at", "last_modified", "id", "status"] + search_fields = ["name"] + + +class ModelCentroidView(ListAPIView): + queryset = Model.objects.filter(status=0) ## only deliver the published model + serializer_class = ModelCentroidSerializer + filter_backends = ( + # InBBoxFilter, + DjangoFilterBackend, + filters.SearchFilter, + ) + filterset_fields = ["id"] + search_fields = ["name"] + pagination_class = None + + +class UsersView(ListAPIView): + authentication_classes = [OsmAuthentication] + permission_classes = [IsOsmAuthenticated] + queryset = OsmUser.objects.all() + serializer_class = UserSerializer + filter_backends = ( + # InBBoxFilter, + DjangoFilterBackend, + filters.SearchFilter, + ) + filterset_fields = ["id"] + search_fields = ["username", "id"] class AOIViewSet(viewsets.ModelViewSet):