Skip to content

Commit

Permalink
Merge pull request #286 from hotosm/feature/model-filters
Browse files Browse the repository at this point in the history
Feature : Model filters - Advancement in GET API
  • Loading branch information
kshitijrajsharma authored Oct 7, 2024
2 parents 3412739 + 527827e commit 669ceae
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 24 deletions.
12 changes: 8 additions & 4 deletions backend/aiproject/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
https://docs.aiproject.com/en/3.1/ref/settings/
"""

import os
import logging
import os
from socket import gethostbyname, gethostname

import dj_database_url
import environ
from corsheaders.defaults import default_headers
from socket import gethostbyname, gethostname

env = environ.Env()

Expand Down Expand Up @@ -102,6 +103,7 @@
CORS_ORIGIN_WHITELIST = ALLOWED_ORIGINS

CORS_ORIGIN_ALLOW_ALL = env("CORS_ORIGIN_ALLOW_ALL", default=False)
DEFAULT_PAGINATION_SIZE = env("DEFAULT_PAGINATION_SIZE", default=50)

REST_FRAMEWORK = {
"DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.coreapi.AutoSchema",
Expand All @@ -110,6 +112,8 @@
"rest_framework.authentication.BasicAuthentication",
"login.authentication.OsmAuthentication",
],
"DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.LimitOffsetPagination",
"PAGE_SIZE": DEFAULT_PAGINATION_SIZE,
}

ROOT_URLCONF = "aiproject.urls"
Expand Down Expand Up @@ -208,7 +212,7 @@
}
}
# get ramp home and set it to environ
RAMP_HOME = env("RAMP_HOME",default=None)
RAMP_HOME = env("RAMP_HOME", default=None)
if RAMP_HOME:
os.environ["RAMP_HOME"] = RAMP_HOME

Expand All @@ -220,4 +224,4 @@
ENABLE_PREDICTION_API = env("ENABLE_PREDICTION_API", default=False)


TEST_RUNNER = 'tests.test_runners.NoDestroyTestRunner'
TEST_RUNNER = "tests.test_runners.NoDestroyTestRunner"
1 change: 1 addition & 0 deletions backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class ModelStatus(models.IntegerChoices):
name = models.CharField(max_length=255)
created_at = models.DateTimeField(auto_now_add=True)
last_modified = models.DateTimeField(auto_now=True)
description = models.TextField(max_length=500, null=True, blank=True)
created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
published_training = models.PositiveIntegerField(null=True, blank=True)
status = models.IntegerField(default=-1, choices=ModelStatus.choices) #
Expand Down
75 changes: 58 additions & 17 deletions backend/core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,31 @@ def create(self, validated_data):
return super().create(validated_data)


class UserSerializer(serializers.ModelSerializer):
class Meta:
model = OsmUser
fields = [
"osm_id",
"username",
# "is_superuser",
# "is_active",
# "is_staff",
"date_joined",
# "email",
"img_url",
# "user_permissions",
]


class ModelSerializer(
serializers.ModelSerializer
): # serializers are used to translate models objects to api
created_by = UserSerializer(read_only=True)
accuracy = serializers.SerializerMethodField()

class Meta:
model = Model
fields = "__all__" # defining all the fields to be included in curd for now , we can restrict few if we want
fields = "__all__"
read_only_fields = (
"created_at",
"last_modified",
Expand All @@ -47,6 +66,44 @@ def create(self, validated_data):
validated_data["created_by"] = user
return super().create(validated_data)

def get_accuracy(
self, obj
): ## this might have performance problem when db grows bigger , consider adding indexes / view in db
training = Training.objects.filter(id=obj.published_training).first()
if training:
return training.accuracy
return None


class ModelCentroidSerializer(GeoFeatureModelSerializer):
geometry = serializers.SerializerMethodField()
mid = serializers.IntegerField(source="id")

class Meta:
model = Model
geo_field = "geometry"
fields = ("mid", "name", "geometry")

def get_geometry(self, obj):
"""
Get the centroid of the AOI linked to the dataset of the given model.
"""
aoi = AOI.objects.filter(dataset=obj.dataset).first()
if aoi and aoi.geom:
return {
"type": "Point",
"coordinates": aoi.geom.centroid.coords,
}
return None

# def to_representation(self, instance):
# """
# Override to_representation to customize GeoJSON structure.
# """
# representation = super().to_representation(instance)
# representation["properties"]["id"] = representation.pop("id")
# return representation


class AOISerializer(
GeoFeatureModelSerializer
Expand Down Expand Up @@ -314,19 +371,3 @@ def validate(self, data):
data["area_threshold"]
)
return data


class UserSerializer(serializers.ModelSerializer):
class Meta:
model = OsmUser
fields = [
"osm_id",
"username",
"is_superuser",
"is_active",
"is_staff",
"date_joined",
"email",
"img_url",
"user_permissions",
]
4 changes: 4 additions & 0 deletions backend/core/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
GenerateFeedbackAOIGpxView,
GenerateGpxView,
LabelViewSet,
ModelCentroidView,
ModelViewSet,
RawdataApiAOIView,
RawdataApiFeedbackView,
TrainingViewSet,
TrainingWorkspaceDownloadView,
TrainingWorkspaceView,
UsersView,
download_training_data,
geojson2osmconverter,
publish_training,
Expand Down Expand Up @@ -51,6 +53,8 @@
"label/feedback/osm/fetch/<int:feedbackaoi_id>/",
RawdataApiFeedbackView.as_view(),
),
path("users/", UsersView.as_view(), name="user-list-view"),
path("models/centroid/", ModelCentroidView.as_view(), name="model-centroid"),
# path("download/<int:dataset_id>/", download_training_data),
path("training/status/<str:run_id>/", run_task_status),
path("training/publish/<int:training_id>/", publish_training),
Expand Down
51 changes: 48 additions & 3 deletions backend/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
from geojson2osm import geojson2osm
from orthogonalizer import othogonalize_poly
from osmconflator import conflate_geojson
from rest_framework import decorators, serializers, status, viewsets
from rest_framework import decorators, filters, serializers, status, viewsets
from rest_framework.decorators import api_view
from rest_framework.exceptions import ValidationError
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework_gis.filters import InBBoxFilter, TMSTileFilter
Expand All @@ -47,6 +48,7 @@
FeedbackLabel,
Label,
Model,
OsmUser,
Training,
)
from .serializers import (
Expand All @@ -59,8 +61,10 @@
FeedbackParamSerializer,
FeedbackSerializer,
LabelSerializer,
ModelCentroidSerializer,
ModelSerializer,
PredictionParamSerializer,
UserSerializer,
)
from .tasks import train_model
from .utils import get_dir_size, gpx_generator, process_rawdata, request_rawdata
Expand Down Expand Up @@ -236,8 +240,49 @@ class ModelViewSet(
permission_classes = [IsOsmAuthenticated]
permission_allowed_methods = ["GET"]
queryset = Model.objects.all()
serializer_class = ModelSerializer # connecting serializer
filterset_fields = ["status"]
filter_backends = (
InBBoxFilter, # it will take bbox like this api/v1/model/?in_bbox=-90,29,-89,35 ,
DjangoFilterBackend,
filters.SearchFilter,
filters.OrderingFilter,
)
serializer_class = ModelSerializer
filterset_fields = {
"status": ["exact"],
"created_at": ["exact", "gt", "gte", "lt", "lte"],
"last_modified": ["exact", "gt", "gte", "lt", "lte"],
"created_by": ["exact"],
"id": ["exact"],
}
ordering_fields = ["created_at", "last_modified", "id", "status"]
search_fields = ["name"]


class ModelCentroidView(ListAPIView):
queryset = Model.objects.filter(status=0) ## only deliver the published model
serializer_class = ModelCentroidSerializer
filter_backends = (
# InBBoxFilter,
DjangoFilterBackend,
filters.SearchFilter,
)
filterset_fields = ["id"]
search_fields = ["name"]
pagination_class = None


class UsersView(ListAPIView):
authentication_classes = [OsmAuthentication]
permission_classes = [IsOsmAuthenticated]
queryset = OsmUser.objects.all()
serializer_class = UserSerializer
filter_backends = (
# InBBoxFilter,
DjangoFilterBackend,
filters.SearchFilter,
)
filterset_fields = ["id"]
search_fields = ["username", "id"]


class AOIViewSet(viewsets.ModelViewSet):
Expand Down

0 comments on commit 669ceae

Please sign in to comment.