Skip to content

Commit

Permalink
Merge pull request #291 from hotosm/feature/vector-tiles-training-files
Browse files Browse the repository at this point in the history
Feature/vector tiles training files
  • Loading branch information
kshitijrajsharma authored Oct 17, 2024
2 parents 0497eea + 40b8060 commit 55d160c
Show file tree
Hide file tree
Showing 16 changed files with 357 additions and 93 deletions.
20 changes: 17 additions & 3 deletions backend/core/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

@admin.register(Dataset)
class DatasetAdmin(geoadmin.OSMGeoAdmin):
list_display = ["name", "created_by"]
list_display = ["name", "user"]


@admin.register(Model)
class ModelAdmin(geoadmin.OSMGeoAdmin):
list_display = ["get_dataset_id", "name", "status", "created_at", "created_by"]
list_display = ["get_dataset_id", "name", "status", "created_at", "user"]

def get_dataset_id(self, obj):
return obj.dataset.id
Expand All @@ -28,7 +28,7 @@ class TrainingAdmin(geoadmin.OSMGeoAdmin):
"description",
"status",
"zoom_level",
"created_by",
"user",
"accuracy",
]
list_filter = ["status"]
Expand All @@ -47,3 +47,17 @@ class FeedbackAOIAdmin(geoadmin.OSMGeoAdmin):
@admin.register(Feedback)
class FeedbackAdmin(geoadmin.OSMGeoAdmin):
list_display = ["feedback_type", "training", "user", "created_at"]


@admin.register(Banner)
class BannerAdmin(admin.ModelAdmin):
    """Admin configuration for banners.

    Adds a computed, read-only "Currently Displayable" column whose value
    comes from calling ``is_displayable()`` on the listed object.
    """

    search_fields = ("message",)
    list_filter = ("start_date", "end_date")
    readonly_fields = ("is_displayable",)
    list_display = ("message", "start_date", "end_date", "is_displayable")

    def is_displayable(self, obj):
        """Return the result of ``obj.is_displayable()`` for the admin column."""
        displayable = obj.is_displayable()
        return displayable

    # Admin display options for the computed column: show it as a boolean
    # value and use a human-readable column header.
    is_displayable.short_description = "Currently Displayable"
    is_displayable.boolean = True
38 changes: 30 additions & 8 deletions backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django.contrib.postgres.fields import ArrayField
from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models

from django.utils import timezone
from login.models import OsmUser

# Create your models here.
Expand All @@ -15,7 +15,7 @@ class DatasetStatus(models.IntegerChoices):
DRAFT = -1

name = models.CharField(max_length=255)
created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
last_modified = models.DateTimeField(auto_now=True)
created_at = models.DateTimeField(auto_now_add=True)
source_imagery = models.URLField(blank=True, null=True)
Expand Down Expand Up @@ -47,6 +47,11 @@ class Label(models.Model):


class Model(models.Model):
BASE_MODEL_CHOICES = (
("RAMP", "RAMP"),
("YOLO", "YOLO"),
)

class ModelStatus(models.IntegerChoices):
ARCHIVED = 1
PUBLISHED = 0
Expand All @@ -57,9 +62,12 @@ class ModelStatus(models.IntegerChoices):
created_at = models.DateTimeField(auto_now_add=True)
last_modified = models.DateTimeField(auto_now=True)
description = models.TextField(max_length=500, null=True, blank=True)
created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
published_training = models.PositiveIntegerField(null=True, blank=True)
status = models.IntegerField(default=-1, choices=ModelStatus.choices) #
status = models.IntegerField(default=-1, choices=ModelStatus.choices)
base_model = models.CharField(
choices=BASE_MODEL_CHOICES, default="RAMP", max_length=10
)


class Training(models.Model):
Expand All @@ -81,14 +89,15 @@ class Training(models.Model):
models.PositiveIntegerField(),
size=4,
)
created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)
started_at = models.DateTimeField(null=True, blank=True)
finished_at = models.DateTimeField(null=True, blank=True)
accuracy = models.FloatField(null=True, blank=True)
epochs = models.PositiveIntegerField()
chips_length = models.PositiveIntegerField(default=0)
batch_size = models.PositiveIntegerField()
freeze_layers = models.BooleanField(default=False)
centroid = geomodels.PointField(srid=4326, null=True, blank=True)


class Feedback(models.Model):
Expand Down Expand Up @@ -146,6 +155,19 @@ class ApprovedPredictions(models.Model):
srid=4326
) ## Making this geometry field to support point/line prediction later on
approved_at = models.DateTimeField(auto_now_add=True)
approved_by = models.ForeignKey(
OsmUser, to_field="osm_id", on_delete=models.CASCADE
)
user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE)


class Banner(models.Model):
    """A text message with a start date and an optional end date."""

    message = models.TextField()
    start_date = models.DateTimeField(default=timezone.now)
    end_date = models.DateTimeField(null=True, blank=True)

    def __str__(self):
        return self.message

    def is_displayable(self):
        """Return True when the current time lies inside the display window.

        The window opens at ``start_date`` (inclusive). When ``end_date`` is
        set it closes at ``end_date`` (inclusive); otherwise it is open-ended.
        """
        current_time = timezone.now()
        # Not started yet.
        if self.start_date > current_time:
            return False
        # Started and no end date: displayable indefinitely.
        if self.end_date is None:
            return True
        return self.end_date >= current_time
21 changes: 16 additions & 5 deletions backend/core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ class Meta:
model = Dataset
fields = "__all__"  # Include every model field in CRUD operations for now; this can be restricted later if needed.
read_only_fields = (
"created_by",
"user",
"created_at",
"last_modified",
)

def create(self, validated_data):
user = self.context["request"].user
validated_data["created_by"] = user
validated_data["user"] = user
return super().create(validated_data)


Expand All @@ -46,7 +46,7 @@ class Meta:


class ModelSerializer(serializers.ModelSerializer):
created_by = UserSerializer(read_only=True)
user = UserSerializer(read_only=True)
accuracy = serializers.SerializerMethodField()
thumbnail_url = serializers.SerializerMethodField()

Expand All @@ -56,13 +56,13 @@ class Meta:
read_only_fields = (
"created_at",
"last_modified",
"created_by",
"user",
"published_training",
)

def create(self, validated_data):
user = self.context["request"].user
validated_data["created_by"] = user
validated_data["user"] = user
return super().create(validated_data)

# def get_training(self, obj):
Expand Down Expand Up @@ -393,3 +393,14 @@ def validate(self, data):
data["area_threshold"]
)
return data


class BannerSerializer(serializers.ModelSerializer):
    """Serializer exposing a banner's id, message, start date and end date."""

    class Meta:
        model = Banner
        fields = ["id", "message", "start_date", "end_date"]
30 changes: 23 additions & 7 deletions backend/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@
import logging
import os
import shutil
import subprocess
import sys
import tarfile
import traceback
from shutil import rmtree

from celery import shared_task
from django.conf import settings
from django.contrib.gis.db.models.aggregates import Extent
from django.contrib.gis.geos import GEOSGeometry
from django.shortcuts import get_object_or_404
from django.utils import timezone

from core.models import AOI, Feedback, FeedbackAOI, FeedbackLabel, Label, Training
from core.serializers import (
AOISerializer,
Expand All @@ -23,6 +18,11 @@
LabelFileSerializer,
)
from core.utils import bbox, is_dir_empty
from django.conf import settings
from django.contrib.gis.db.models.aggregates import Extent
from django.contrib.gis.geos import GEOSGeometry
from django.shortcuts import get_object_or_404
from django.utils import timezone

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -135,6 +135,10 @@ def train_model(
raise ValueError(
f"No AOI is attached with supplied dataset id:{dataset_id}, Create AOI first",
)
first_aoi_centroid = aois[0].geom.centroid
training_instance.centroid = first_aoi_centroid
training_instance.save()

for obj in aois:
bbox_coords = bbox(obj.geom.coords[0])
for z in zoom_level:
Expand Down Expand Up @@ -309,6 +313,18 @@ def train_model(
) as f:
f.write(json.dumps(aoi_serializer.data))

# Generate a vector-tile archive (meta.pmtiles, zoom 7-18) with one layer per
# GeoJSON file. The command is an argument list executed without a shell, so
# paths containing spaces or shell metacharacters are passed through verbatim
# instead of being interpreted.
tippecanoe_command = [
    "tippecanoe",
    "-o",
    os.path.join(output_path, "meta.pmtiles"),
    "-Z7",
    "-z18",
    "-L",
    "aois:" + os.path.join(output_path, "aois.geojson"),
    "-L",
    "labels:" + os.path.join(output_path, "labels.geojson"),
    "--force",
    "--read-parallel",
    "-rg",
    "--drop-densest-as-needed",
]
logger.info("Starting to generate vector tiles for aois and labels")
try:
    result = subprocess.run(tippecanoe_command, check=True, capture_output=True)
except subprocess.CalledProcessError as ex:
    # Log both captured streams: on failure the diagnostics may be on stderr
    # rather than stdout.
    logger.error(ex.output)
    logger.error(ex.stderr)
    # Bare raise keeps the original traceback intact.
    raise
logger.info(result.stdout.decode("utf-8"))
logger.info("Vector tile generation done !")

# copy aois and labels to preprocess output before compressing it to tar
shutil.copyfile(
os.path.join(output_path, "aois.geojson"),
Expand All @@ -332,7 +348,7 @@ def train_model(
training_instance.save()
response = {}
response["accuracy"] = float(final_accuracy)
# response["model_path"] = os.path.join(output_path, "checkpoint.tf")
response["tiles_path"] = os.path.join(output_path, "meta.pmtiles")
response["model_path"] = os.path.join(output_path, "checkpoint.h5")
response["graph_path"] = os.path.join(output_path, "graphs")
sys.stdout = sys.__stdout__
Expand Down
4 changes: 4 additions & 0 deletions backend/core/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .views import ( # APIStatus,
AOIViewSet,
ApprovedPredictionsViewSet,
BannerViewSet,
ConflateGeojson,
DatasetViewSet,
FeedbackAOIViewset,
Expand All @@ -26,6 +27,7 @@
UsersView,
download_training_data,
geojson2osmconverter,
get_kpi_stats,
publish_training,
run_task_status,
)
Expand All @@ -44,6 +46,7 @@
router.register(r"feedback", FeedbackViewset)
router.register(r"feedback-aoi", FeedbackAOIViewset)
router.register(r"feedback-label", FeedbackLabelViewset)
router.register(r"banner", BannerViewSet)


urlpatterns = [
Expand Down Expand Up @@ -71,6 +74,7 @@
"workspace/download/<path:lookup_dir>/", TrainingWorkspaceDownloadView.as_view()
),
path("workspace/<path:lookup_dir>/", TrainingWorkspaceView.as_view()),
path("kpi/stats/", get_kpi_stats, name="get_kpi_stats"),
]
if settings.ENABLE_PREDICTION_API:
urlpatterns.append(path("prediction/", PredictionView.as_view()))
Loading

0 comments on commit 55d160c

Please sign in to comment.