diff --git a/ide/static/css/dash_style.css b/ide/static/css/dash_style.css new file mode 100644 index 000000000..fe720deb5 --- /dev/null +++ b/ide/static/css/dash_style.css @@ -0,0 +1,21 @@ +.overlay { + opacity: 0; + z-index: -2; + height: 290px; + width: 240px; + background: rgb(34,47,62,0.9); + border-radius: 20px; + position: relative; + top: -310px; + transition: all .4s ease; +} + +.card { + transition: all .4s ease; +} + +.card:hover + .overlay, .overlay:hover { + opacity: 1; + z-index: 1; + transition: all .4s ease; +} diff --git a/ide/static/css/login_style.css b/ide/static/css/login_style.css index b4228b7f1..bf3c37ea9 100644 --- a/ide/static/css/login_style.css +++ b/ide/static/css/login_style.css @@ -15,6 +15,36 @@ cursor: pointer; } +#sidebar-logout-button { + background: rgb(205, 207, 210); + color: rgb(69, 80, 97); + text-align: center; + border-radius: 5px; + width: 110px; + margin: 0.2em; + transition: 0.2s; + position: relative; +} + +#sidebar-logout-button:hover { + cursor: pointer; +} + +#sidebar-dash-button { + background: rgb(205, 207, 210); + color: rgb(69, 80, 97); + text-align: center; + border-radius: 5px; + width: 130px; + margin: 0.2em; + transition: 0.2s; + position: relative; +} + +#sidebar-dash-button:hover { + cursor: pointer; +} + #sidebar-login-button span { position: absolute; left: 9px; @@ -67,6 +97,13 @@ } } +.long-buttons { + display: flex; + flex: 1; + flex-direction: row; + margin-top: -0.5em; +} + .login-panel { position: relative; width: 350px; @@ -265,4 +302,4 @@ position: absolute; top: 0px; right: 10px; -} \ No newline at end of file +} diff --git a/ide/static/css/searchbar_style.css b/ide/static/css/searchbar_style.css index d7e3c537a..02e86a05d 100644 --- a/ide/static/css/searchbar_style.css +++ b/ide/static/css/searchbar_style.css @@ -1,3 +1,7 @@ +body { + background: #F3F5F7; +} + .insert-layer-title { position: relative; } diff --git a/ide/static/js/card.js b/ide/static/js/card.js new file mode 100644 index 000000000..c57152847 --- /dev/null +++ b/ide/static/js/card.js @@ -0,0 +1,79 @@ +import React from 'react'; +import '../css/dash_style.css'; + +class Card extends React.Component { + render() { + return( +
+
+ +
+
+
+
+
+ +

+ +

+
+
+
+ this.props.ModelFunction(this.props.ModelID) + } style={{ + textDecoration: 'none', + color: '#fff', + cursor: 'pointer' + }}> +

+ +

+
+
+
+
+

{this.props.ModelName}

+
+ ); + } +} + +Card.propTypes = { + ModelName: React.PropTypes.string, + ModelID: React.PropTypes.number, + ModelFunction: React.PropTypes.func +}; + + +export default Card; diff --git a/ide/static/js/content.js b/ide/static/js/content.js index 9a0c92d4c..ba151f864 100644 --- a/ide/static/js/content.js +++ b/ide/static/js/content.js @@ -85,6 +85,7 @@ class Content extends React.Component { this.openModal = this.openModal.bind(this); this.closeModal = this.closeModal.bind(this); this.saveDb = this.saveDb.bind(this); + this.saveModel = this.saveModel.bind(this); this.loadDb = this.loadDb.bind(this); this.infoModal = this.infoModal.bind(this); this.faqModal = this.faqModal.bind(this); @@ -979,6 +980,34 @@ class Content extends React.Component { layer.info.phase = 0; this.setState({ net }); } + saveModel(){ + let modelData = this.state.net; + this.setState({ load: true }); + $.ajax({ + url: '/saveModel', + dataType: 'json', + type: 'POST', + data: { + net: JSON.stringify(modelData), + net_name: this.state.net_name, + user_id: this.getUserId(), + nextLayerId: this.state.nextLayerId + }, + success : function (response) { + if (response.result == 'success') { + this.modalContent = "Successfully Saved!"; + this.openModal(); + } + else if (response.result == 'error') { + this.addError(response.error); + } + this.setState({ load: false }); + }.bind(this), + error() { + this.setState({ load: false }); + } + }); + } saveDb(){ let netData = this.state.net; this.setState({ load: true }); @@ -1055,7 +1084,7 @@ class Content extends React.Component { // Note: this needs to be improved when handling conflict resolution to avoid // inconsistent states of model let nextLayerId = this.state.nextLayerId; - + let is_shared = false; this.setState({ load: true }); this.dismissAllErrors(); @@ -1072,6 +1101,13 @@ class Content extends React.Component { // while loading a model ensure paramete intialisation // for UI show/hide is not executed, it leads to inconsistent // data which cannot be used further + if (response.public_sharing == false) { + is_shared = false; + } + else { + is_shared = true; + } + console.log(response); nextLayerId = response.next_layer_id; this.initialiseImportedNet(response.net,response.net_name); if (Object.keys(response.net).length){ @@ -1083,8 +1119,10 @@ class Content extends React.Component { } this.setState({ load: false, - isShared: true, + isShared: is_shared, nextLayerId: parseInt(nextLayerId) + }, function() { + console.log("Shared value: " + this.state.isShared); }); }.bind(this), error() { @@ -1092,6 +1130,7 @@ class Content extends React.Component { } }); } + infoModal() { this.modalHeader = "About" this.modalContent = `Fabrik is an online collaborative platform to build and visualize deep\ @@ -1113,7 +1152,7 @@ class Content extends React.Component { here.
Q: What do the Train/Test buttons mean?
- A: They are two different modes of your model: + A: They are two different modes of your model: Train and Test - respectively for training your model with data and testing how and if it works.
Q: What does the import fuction do?
A: It allows you to import your previously created models in Caffe (.protoxt files), @@ -1127,7 +1166,7 @@ class Content extends React.Component { A: Please see the instructions listed here

- + If you have anymore questions, please visit Fabrik's Github page available here for more information.

); @@ -1282,6 +1321,7 @@ class Content extends React.Component { this.addNewLayer(layer); } } + render() { let loader = null; if (this.state.load) { @@ -1299,9 +1339,11 @@ class Content extends React.Component { ; + } + else { + var data_array = JSON.parse(localStorage.getItem("obj")); + var len = Object.keys(data_array).length/2; + var elements=[]; + for (var i = 1; i < len+1; i++) { + elements.push() + } + } + return ( +
+
+

DASHBOARD

+
+

+ +  CREATE NEW MODEL +

+
+
+ {elements} +
+
+ + +

{ this.modalHeader }

+ { this.modalContent } +
+
+ ); + } + else { + window.open("#","_self"); + return null; + } + } +} + +export default Dashboard; diff --git a/ide/static/js/dashbutton.js b/ide/static/js/dashbutton.js new file mode 100644 index 000000000..3a45a0a41 --- /dev/null +++ b/ide/static/js/dashbutton.js @@ -0,0 +1,22 @@ +import React from 'react'; + +class DashButton extends React.Component { + constructor(props) { + super(props); + this.openDash = this.openDash.bind(this); + } + openDash(){ + window.location.href = "/#/dashboard"; + } + render(){ + return( +
+ +
+ ); + } +} + +export default DashButton; diff --git a/ide/static/js/index.js b/ide/static/js/index.js index d387a3637..1448af664 100644 --- a/ide/static/js/index.js +++ b/ide/static/js/index.js @@ -1,12 +1,13 @@ - import React from 'react'; import { render } from 'react-dom'; import { Router, Route, hashHistory } from 'react-router'; import App from './app.js'; +import Dashboard from './dashboard.js'; import '../css/style.css'; render( + , document.getElementById('app') ); diff --git a/ide/static/js/login.js b/ide/static/js/login.js index a31214afa..7165f4dae 100644 --- a/ide/static/js/login.js +++ b/ide/static/js/login.js @@ -1,4 +1,5 @@ import React from 'react'; +import DashButton from './dashbutton'; class Login extends React.Component { constructor(props) { @@ -23,6 +24,7 @@ class Login extends React.Component { contentType: false, success: function (response) { if (response) { + localStorage.removeItem("userID"); this.setState({ loginState: false }); this.props.setUserId(null); this.props.setUserName(null); @@ -67,6 +69,7 @@ class Login extends React.Component { if (response.result) { this.setState({ loginState: response.result }); this.props.setUserId(response.user_id); + localStorage.setItem("userID",response.user_id); this.props.setUserName(response.username); if (showNotification) { @@ -181,12 +184,15 @@ class Login extends React.Component { if(this.state.loginState) { return (
-
done
diff --git a/ide/static/js/topBar.js b/ide/static/js/topBar.js index b3a02e437..b95712e3f 100644 --- a/ide/static/js/topBar.js +++ b/ide/static/js/topBar.js @@ -4,20 +4,14 @@ import ReactTooltip from 'react-tooltip'; class TopBar extends React.Component { constructor(props) { super(props); - this.checkURL = this.checkURL.bind(this); + this.state = {isShared: false}; } - checkURL() { - let url = window.location.href; - let urlParams = url.indexOf("load"); - - if(urlParams != -1) { - return true; - } - return false; + componentWillReceiveProps(newProps){ + this.setState({isShared: newProps.isShared}); } render() { let content = null; - if (this.checkURL()) { + if (this.state.isShared == true) { content = (
{content} +
+
+
+ +
+
+
@@ -116,11 +120,13 @@ TopBar.propTypes = { exportNet: React.PropTypes.func, importNet: React.PropTypes.func, saveDb: React.PropTypes.func, + saveModel: React.PropTypes.func, loadDb: React.PropTypes.func, zooModal: React.PropTypes.func, textboxModal: React.PropTypes.func, urlModal: React.PropTypes.func, - updateHistoryModal: React.PropTypes.func + updateHistoryModal: React.PropTypes.func, + isShared: React.PropTypes.bool }; export default TopBar; diff --git a/ide/urls.py b/ide/urls.py index 7798fdc97..140087b11 100644 --- a/ide/urls.py +++ b/ide/urls.py @@ -3,7 +3,9 @@ from django.conf.urls.static import static from django.conf import settings from views import index, calculate_parameter, fetch_layer_shape -from views import load_from_db, save_to_db, fetch_model_history +from views import load_from_db, load_model_from_db, \ + delete_model_from_db, save_to_db, save_model_to_db, \ + fetch_model_history urlpatterns = [ url(r'^$', index), @@ -14,9 +16,15 @@ url(r'^keras/', include('keras_app.urls')), url(r'^tensorflow/', include('tensorflow_app.urls')), url(r'^save$', save_to_db, name='saveDB'), + url(r'^saveModel$', save_model_to_db, name='saveModel'), url(r'^load*', load_from_db, name='loadDB'), + url(r'^deleteModel$', delete_model_from_db, name='deleteModel'), + url(r'^getModel$', load_model_from_db, name='getModelData'), url(r'^model_history', fetch_model_history, name='model-history'), - url(r'^model_parameter/', calculate_parameter, name='calculate-parameter'), - url(r'^layer_parameter/', fetch_layer_shape, name='fetch-layer-shape') -] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) + \ - static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) + url(r'^model_parameter/', calculate_parameter, + name='calculate-parameter'), + url(r'^layer_parameter/', fetch_layer_shape, + name='fetch-layer-shape'), + ] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) \ + + static(settings.STATIC_URL, + document_root=settings.STATIC_ROOT) diff --git a/ide/views.py b/ide/views.py index 9cd03ea87..4cc94b8b4 100644 --- a/ide/views.py +++ b/ide/views.py @@ -8,7 +8,8 @@ from django.http import JsonResponse from django.views.decorators.csrf import csrf_exempt from django.contrib.auth.models import User -from utils.shapes import get_shapes, get_layer_shape, handle_concat_layer +from utils.shapes import get_shapes, get_layer_shape, \ + handle_concat_layer def index(request): @@ -24,38 +25,65 @@ def fetch_layer_shape(request): net[layerId]['shape'] = {} net[layerId]['shape']['input'] = None net[layerId]['shape']['output'] = None - dataLayers = ['ImageData', 'Data', 'HDF5Data', 'Input', 'WindowData', 'MemoryData', 'DummyData'] + dataLayers = [ + 'ImageData', + 'Data', + 'HDF5Data', + 'Input', + 'WindowData', + 'MemoryData', + 'DummyData', + ] # Obtain input shape of new layer - if (net[layerId]['info']['type'] == "Concat"): + + if net[layerId]['info']['type'] == 'Concat': for parentLayerId in net[layerId]['connection']['input']: + # Check if parent layer have shapes - if (net[parentLayerId]['shape']['output']): - net[layerId]['shape']['input'] = handle_concat_layer(net[layerId], net[parentLayerId]) - elif (not (net[layerId]['info']['type'] in dataLayers)): - if (len(net[layerId]['connection']['input']) > 0): + + if net[parentLayerId]['shape']['output']: + net[layerId]['shape']['input'] = \ + handle_concat_layer(net[layerId], + net[parentLayerId]) + elif not net[layerId]['info']['type'] in dataLayers: + if len(net[layerId]['connection']['input']) > 0: parentLayerId = net[layerId]['connection']['input'][0] + # Check if parent layer have shapes - if (net[parentLayerId]['shape']['output']): - net[layerId]['shape']['input'] = net[parentLayerId]['shape']['output'][:] + + if net[parentLayerId]['shape']['output']: + net[layerId]['shape']['input'] = \ + (net[parentLayerId]['shape']['output'])[:] # Obtain output shape of new layer - if (net[layerId]['info']['type'] in dataLayers): + + if net[layerId]['info']['type'] in dataLayers: + # handling Data Layers separately - if ('dim' in net[layerId]['params'] and len(net[layerId]['params']['dim'])): + + if 'dim' in net[layerId]['params'] \ + and len(net[layerId]['params']['dim']): + # layers with empty dim parameter can't be passed - net[layerId]['shape']['input'], net[layerId]['shape']['output'] =\ - get_layer_shape(net[layerId]) - elif ('dim' not in net[layerId]['params']): + + (net[layerId]['shape']['input'], + net[layerId]['shape']['output']) = \ + get_layer_shape(net[layerId]) + elif 'dim' not in net[layerId]['params']: + # shape calculation for layers with no dim param - net[layerId]['shape']['input'], net[layerId]['shape']['output'] =\ - get_layer_shape(net[layerId]) + + (net[layerId]['shape']['input'], + net[layerId]['shape']['output']) = \ + get_layer_shape(net[layerId]) else: - if (net[layerId]['shape']['input']): - net[layerId]['shape']['output'] = get_layer_shape(net[layerId]) + if net[layerId]['shape']['input']: + net[layerId]['shape']['output'] = \ + get_layer_shape(net[layerId]) except BaseException: - return JsonResponse({ - 'result': 'error', 'error': str(sys.exc_info()[1])}) + return JsonResponse({'result': 'error', + 'error': str(sys.exc_info()[1])}) return JsonResponse({'result': 'success', 'net': net}) @@ -64,57 +92,118 @@ def calculate_parameter(request): if request.method == 'POST': net = yaml.safe_load(request.POST.get('net')) try: + # While calling get_shapes we need to remove the flag # added in frontend to show the parameter on pane + netObj = copy.deepcopy(net) for layerId in netObj: for param in netObj[layerId]['params']: - netObj[layerId]['params'][param] = netObj[layerId]['params'][param][0] + netObj[layerId]['params'][param] = \ + netObj[layerId]['params'][param][0] + # use get_shapes method to obtain shapes of each layer + netObj = get_shapes(netObj) for layerId in net: net[layerId]['shape'] = {} net[layerId]['shape']['input'] = netObj[layerId]['shape']['input'] - net[layerId]['shape']['output'] = netObj[layerId]['shape']['output'] + net[layerId]['shape']['output'] = \ + netObj[layerId]['shape']['output'] except BaseException: - return JsonResponse({ - 'result': 'error', 'error': str(sys.exc_info()[1])}) + return JsonResponse({'result': 'error', + 'error': str(sys.exc_info()[1])}) return JsonResponse({'result': 'success', 'net': net}) @csrf_exempt -def save_to_db(request): +def delete_model_from_db(request): + if request.method == 'POST': + if 'userID' in request.POST: + userID = request.POST.get('userID') + model_id = request.POST.get('modelid') + model = Network.objects.get(id=model_id) + if model.author_id == int(userID): + model.delete() + return JsonResponse({'result': 'success', + 'data': 'Model successfully deleted!'}) + else: + return JsonResponse({'result': 'error', + 'error': "This model doesn't belong to you!"}) + + +@csrf_exempt +def save(request, public_sharing): if request.method == 'POST': net = request.POST.get('net') net_name = request.POST.get('net_name') user_id = request.POST.get('user_id') next_layer_id = request.POST.get('nextLayerId') - public_sharing = True user = None + if public_sharing is True: + tag = "ModelShared" + else: + tag = "ModelNotShared" if net_name == '': net_name = 'Net' - try: - # making model sharing public by default for now - # TODO: Prvilege on Sharing - if user_id: - user_id = int(user_id) - user = User.objects.get(id=user_id) - - # create a new model on share event - model = Network(name=net_name, public_sharing=public_sharing, author=user) - model.save() - # create first version of model - model_version = NetworkVersion(network=model, network_def=net) - model_version.save() - # create initial update for nextLayerId - model_update = NetworkUpdates(network_version=model_version, - updated_data=json.dumps({'nextLayerId': next_layer_id}), - tag='ModelShared') - model_update.save() - - return JsonResponse({'result': 'success', 'id': model.id}) - except: - return JsonResponse({'result': 'error', 'error': str(sys.exc_info()[1])}) + + if Network.objects.filter(name=net_name).exists(): + # Update the exising json field + try: + if user_id: + user_id = int(user_id) + user = User.objects.get(id=user_id) + # load the model with the net name + model = Network.objects.get(name=net_name) + model_id = model.id + # update the model with network id same as model id + existing_model = \ + NetworkVersion.objects.get(network_id=model_id) + existing_model.network_def = net + existing_model.save() + return JsonResponse({'result': 'success', + 'id': model.id}) + except: + return JsonResponse({'result': 'error', + 'error': str(sys.exc_info()[1])}) + else: + try: + if user_id: + user_id = int(user_id) + user = User.objects.get(id=user_id) + # create a new model on save event + model = Network(name=net_name, + public_sharing=public_sharing, + author=user) + model.save() + # create first version of model + model_version = NetworkVersion(network=model, + network_def=net) + model_version.save() + # create initial update for nextLayerId + model_update = \ + NetworkUpdates(network_version=model_version, + updated_data=json.dumps( + {'nextLayerId': next_layer_id}), + tag=tag) + model_update.save() + return JsonResponse({'result': 'success', + 'id': model.id}) + except: + return JsonResponse({'result': 'error', + 'error': str(sys.exc_info()[1])}) + + +@csrf_exempt +def save_model_to_db(request): + response = save(request, False) + return response + + +@csrf_exempt +def save_to_db(request): + response = save(request, True) + return response def create_network_version(network_def, updates_batch): @@ -129,7 +218,9 @@ def create_network_version(network_def, updates_batch): next_layer_id = updated_data['nextLayerId'] if tag == 'UpdateParam': + # Update Param UI event handling + param = updated_data['param'] layer_id = updated_data['layerId'] value = updated_data['value'] @@ -138,9 +229,10 @@ def create_network_version(network_def, updates_batch): network_def[layer_id]['props'][param] = value else: network_def[layer_id]['params'][param][0] = value - elif tag == 'DeleteLayer': + # Delete layer UI event handling + layer_id = updated_data['layerId'] input_layer_ids = network_def[layer_id]['connection']['input'] output_layer_ids = network_def[layer_id]['connection']['output'] @@ -152,48 +244,53 @@ def create_network_version(network_def, updates_batch): network_def[output_layer_id]['connection']['input'].remove(layer_id) del network_def[layer_id] - elif tag == 'AddLayer': + # Add layer UI event handling + prev_layer_id = updated_data['prevLayerId'] new_layer_id = updated_data['layerId'] if isinstance(prev_layer_id, list): for layer_id in prev_layer_id: - network_def[layer_id]['connection']['output'].append(new_layer_id) + network_def[layer_id]['connection']['output' + ].append(new_layer_id) else: - network_def[prev_layer_id]['connection']['output'].append(new_layer_id) + network_def[prev_layer_id]['connection']['output' + ].append(new_layer_id) network_def[new_layer_id] = updated_data['layer'] - elif tag == 'AddComment': + layer_id = updated_data['layerId'] comment = updated_data['comment'] - if ('comments' not in network_def[layer_id]): + if 'comments' not in network_def[layer_id]: network_def[layer_id]['comments'] = [] network_def[layer_id]['comments'].append(comment) - return { - 'network': network_def, - 'next_layer_id': next_layer_id - } + return {'network': network_def, 'next_layer_id': next_layer_id} def get_network_version(netObj): - network_version = NetworkVersion.objects.filter(network=netObj).order_by('-created_on')[0] - updates_batch = NetworkUpdates.objects.filter(network_version=network_version).order_by('created_on') + network_version = \ + NetworkVersion.objects.filter(network=netObj).order_by('-created_on' + )[0] + updates_batch = NetworkUpdates.objects.filter( + network_version=network_version).order_by('created_on') - return create_network_version(network_version.network_def, updates_batch) + return create_network_version(network_version.network_def, + updates_batch) def get_checkpoint_version(netObj, checkpoint_id): network_update = NetworkUpdates.objects.get(id=checkpoint_id) network_version = network_update.network_version - updates_batch = NetworkUpdates.objects.filter(network_version=network_version)\ - .filter(created_on__lte=network_update.created_on)\ - .order_by('created_on') - return create_network_version(network_version.network_def, updates_batch) + updates_batch = NetworkUpdates.objects.filter( + network_version=network_version).filter( + created_on__lte=network_update.created_on).order_by('created_on') + return create_network_version(network_version.network_def, + updates_batch) @csrf_exempt @@ -201,56 +298,82 @@ def load_from_db(request): if request.method == 'POST': if 'proto_id' in request.POST: try: - model = Network.objects.get(id=int(request.POST['proto_id'])) + model = \ + Network.objects.get(id=int(request.POST['proto_id' + ])) version_id = None data = {} - if 'version_id' in request.POST and request.POST['version_id'] != '': + if 'version_id' in request.POST \ + and request.POST['version_id'] != '': + # added for loading any previous version of model + version_id = int(request.POST['version_id']) data = get_checkpoint_version(model, version_id) else: + # fetch the required version of model + data = get_network_version(model) net = data['network'] next_layer_id = data['next_layer_id'] - - # authorizing the user for access to model - if not model.public_sharing: - return JsonResponse({'result': 'error', - 'error': 'Permission denied for access to model'}) except Exception: + return JsonResponse({'result': 'error', 'error': 'No network file found'}) - return JsonResponse({'result': 'success', 'net': net, 'net_name': model.name, - 'next_layer_id': next_layer_id}) + return JsonResponse({ + 'result': 'success', + 'net': net, + 'net_name': model.name, + 'next_layer_id': next_layer_id, + 'public_sharing': model.public_sharing, + }) if request.method == 'GET': return index(request) +@csrf_exempt +def load_model_from_db(request): + if request.method == 'POST': + if 'userID' in request.POST: + userID = request.POST.get('userID') + if Network.objects.filter(author=userID).exists(): + data = {} + models = Network.objects.filter(author=userID) + i = 1 + for mod in models: + data_index1 = 'Model%d_Name' % i + data_index2 = 'Model%d_ID' % i + data[data_index1] = mod.name + data[data_index2] = mod.id + i += 1 + return JsonResponse({'result': 'success', 'data': data}) + else: + return JsonResponse({'result': 'error', + 'error': 'No models found'}) + + @csrf_exempt def fetch_model_history(request): if request.method == 'POST': try: network_id = int(request.POST['net_id']) network = Network.objects.get(id=network_id) - network_versions = NetworkVersion.objects.filter(network=network).order_by('created_on') + network_versions = NetworkVersion.objects.filter( + network=network).order_by('created_on') modelHistory = {} for version in network_versions: - network_updates = NetworkUpdates.objects.filter(network_version=version)\ - .order_by('created_on') + network_updates = NetworkUpdates.objects.filter( + network_version=version).order_by('created_on') for update in network_updates: modelHistory[update.id] = update.tag - return JsonResponse({ - 'result': 'success', - 'data': modelHistory - }) + return JsonResponse({'result': 'success', + 'data': modelHistory}) except Exception: - return JsonResponse({ - 'result': 'error', - 'error': 'Unable to load model history' - }) + return JsonResponse({'result': 'error', + 'error': 'Unable to load model history'}) diff --git a/settings/test.py b/settings/test.py index 540f7a344..64921989c 100644 --- a/settings/test.py +++ b/settings/test.py @@ -1,4 +1,4 @@ -from .common import * # noqa: ignore=F405 +from .common import * # flake8: noqa # Database # https://docs.djangoproject.com/en/1.9/ref/settings/#databases @@ -8,10 +8,10 @@ DATABASES = { 'default': { 'ENGINE': 'django.db.backends.postgresql_psycopg2', - 'NAME': 'fabrik', - 'USER': 'admin', - 'PASSWORD': 'fabrik', - 'HOST': 'localhost', + 'NAME': 'fabrik', # Change this to 'postgres' if you're using docker + 'USER': 'admin', # Change this to 'postgres' if you're using docker + 'PASSWORD': 'fabrik', # Change this to 'postgres' if you're using docker + 'HOST': 'localhost', # Change this to 'db' if you're using docker 'PORT': 5432, } } diff --git a/tests/unit/caffe_app/test_db.py b/tests/unit/caffe_app/test_db.py index ca0fa4a88..8dea0a934 100644 --- a/tests/unit/caffe_app/test_db.py +++ b/tests/unit/caffe_app/test_db.py @@ -51,3 +51,83 @@ def test_load_nofile(self): response = json.loads(response.content) self.assertEqual(response['result'], 'error') self.assertEqual(response['error'], 'No network file found') + + +class SaveModelToDBTest(unittest.TestCase): + + def setUp(self): + self.client = Client() + + def test_save_json1(self): + tests = open(os.path.join(settings.BASE_DIR, 'tests', 'unit', 'ide', + 'caffe_export_test.json'), 'r') + net = json.load(tests)['net'] + response = self.client.post( + reverse('saveModel'), + {'net': net, 'net_name': 'netname'}) + response = json.loads(response.content) + self.assertEqual(response['result'], 'success') + + def test_load1(self): + u_3 = User(id=3, username='user_3') + u_3.save() + u_4 = User(id=4, username='user_4') + u_4.save() + model = Network(name='net') + model.save() + model_version = NetworkVersion(network=model, network_def={}) + model_version.save() + + response = self.client.post( + reverse('saveModel'), + {'net': '{"net": "testnet"}', 'net_name': 'name'}) + response = json.loads(response.content) + self.assertEqual(response['result'], 'success') + self.assertTrue('id' in response) + proto_id = response['id'] + response = self.client.post(reverse('loadDB'), {'proto_id': proto_id}) + response = json.loads(response.content) + self.assertEqual(response['result'], 'success') + self.assertEqual(response['net_name'], 'name') + + def test_load_nofile1(self): + response = self.client.post(reverse('loadDB'), + {'proto_id': 'inexistent'}) + response = json.loads(response.content) + self.assertEqual(response['result'], 'error') + self.assertEqual(response['error'], 'No network file found') + + +class LoadModelFromDB(unittest.TestCase): + + def setUp(self): + self.client = Client() + + def test_load_model(self): + u_5 = User(id=5, username='user_5') + u_5.save() + model = Network(id=9, name='test_net', author_id='5') + model.save() + response = self.client.post( + reverse('getModelData'), {'userID': '5'}) + response = json.loads(response.content) + self.assertEqual(response['result'], 'success') + self.assertEqual(response['data']['Model1_Name'], 'test_net') + self.assertEqual(response['data']['Model1_ID'], 9) + + +class DeleteModelFromDB(unittest.TestCase): + + def setUp(self): + self.client = Client() + + def test_delete_model(self): + u_6 = User(id=6, username='user_6') + u_6.save() + model = Network(id=10, name='test_net2', author_id='6') + model.save() + response = self.client.post( + reverse('deleteModel'), {'userID': '6', 'modelid': '10'}) + response = json.loads(response.content) + self.assertEqual(response['result'], 'success') + self.assertEqual(Network.objects.filter(id=10).exists(), False)