Skip to content
This repository has been archived by the owner on Dec 3, 2019. It is now read-only.

PostgreSQLConnection refactoring #152

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@

package com.github.mauricio.async.db.postgresql

import com.github.mauricio.async.db.QueryResult
import com.github.mauricio.async.db.column.{ColumnEncoderRegistry, ColumnDecoderRegistry}
import com.github.mauricio.async.db.exceptions.{InsufficientParametersException, ConnectionStillRunningQueryException}
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference}

import com.github.mauricio.async.db.column.{ColumnDecoderRegistry, ColumnEncoderRegistry}
import com.github.mauricio.async.db.exceptions.{ConnectionStillRunningQueryException, InsufficientParametersException}
import com.github.mauricio.async.db.general.MutableResultSet
import com.github.mauricio.async.db.pool.TimeoutScheduler
import com.github.mauricio.async.db.postgresql.codec.{PostgreSQLConnectionDelegate, PostgreSQLConnectionHandler}
import com.github.mauricio.async.db.postgresql.column.{PostgreSQLColumnDecoderRegistry, PostgreSQLColumnEncoderRegistry}
import com.github.mauricio.async.db.postgresql.exceptions._
import com.github.mauricio.async.db.postgresql.messages.backend._
import com.github.mauricio.async.db.postgresql.messages.frontend._
import com.github.mauricio.async.db.util._
import com.github.mauricio.async.db.{Configuration, Connection}
import java.util.concurrent.atomic.{AtomicLong,AtomicInteger,AtomicReference}
import messages.backend._
import messages.frontend._
import scala.Some
import scala.concurrent._
import com.github.mauricio.async.db.{Configuration, Connection, QueryResult}
import io.netty.channel.EventLoopGroup
import java.util.concurrent.CopyOnWriteArrayList

import scala.concurrent._

object PostgreSQLConnection {
final val Counter = new AtomicLong()
Expand Down Expand Up @@ -70,26 +70,27 @@ class PostgreSQLConnection
private val parsedStatements = new scala.collection.mutable.HashMap[String, PreparedStatementHolder]()
private var authenticated = false

private val connectionFuture = Promise[Connection]()

private var recentError = false
private val queryPromiseReference = new AtomicReference[Option[Promise[QueryResult]]](None)
private var currentQuery: Option[MutableResultSet[PostgreSQLColumnData]] = None
private var recentError : Option[Throwable] = None
private val resultProcessorReference = new AtomicReference[Option[ResultProcessor]](None)
private var currentPreparedStatement: Option[PreparedStatementHolder] = None
private var version = Version(0,0,0)
private var notifyListeners = new CopyOnWriteArrayList[NotificationResponse => Unit]()

private var queryResult: Option[QueryResult] = None
private val notifyListeners = new CopyOnWriteArrayList[NotificationResponse => Unit]()
private var portalSuspended = false

override def eventLoopGroup : EventLoopGroup = group
def isReadyForQuery: Boolean = this.queryPromise.isEmpty

private def resultProcessor: Option[ResultProcessor] = resultProcessorReference.get()

def isReadyForQuery: Boolean = this.resultProcessor.isEmpty

def connect: Future[Connection] = {
val processor = new ConnectResultProcessor()
setResultProcessor(processor)
this.connectionHandler.connect.onFailure {
case e => this.connectionFuture.tryFailure(e)
case e => processor.failure(e)
}

this.connectionFuture.future
processor.connectionPromise.future
}

override def disconnect: Future[Connection] = this.connectionHandler.disconnect.map( c => this )
Expand All @@ -103,7 +104,7 @@ class PostgreSQLConnection
validateQuery(query)

val promise = Promise[QueryResult]()
this.setQueryPromise(promise)
this.setResultProcessor(new PromiseResultProcessor(promise))

write(new QueryMessage(query))
addTimeout(promise,configuration.queryTimeout)
Expand All @@ -114,18 +115,19 @@ class PostgreSQLConnection
validateQuery(query)

val promise = Promise[QueryResult]()
this.setQueryPromise(promise)
val processor: PromiseResultProcessor = new PromiseResultProcessor(promise)
this.setResultProcessor(processor)

val holder = this.parsedStatements.getOrElseUpdate(query,
new PreparedStatementHolder( query, preparedStatementsCounter.incrementAndGet ))

if (holder.paramsCount != values.length) {
this.clearQueryPromise
this.clearResultProcessor
throw new InsufficientParametersException(holder.paramsCount, values)
}

this.currentPreparedStatement = Some(holder)
this.currentQuery = Some(new MutableResultSet(holder.columnDatas))
processor.columnTypes = holder.columnDatas
write(
if (holder.prepared)
new PreparedStatementExecuteMessage(holder.statementId, holder.realQuery, values, this.encoderRegistry)
Expand All @@ -137,32 +139,29 @@ class PostgreSQLConnection
promise.future
}

override def onError( exception : Throwable ) {
this.setErrorOnFutures(exception)
override def onError(exception : Throwable) {
recentError = Some(exception)
}

def hasRecentError: Boolean = this.recentError
def hasRecentError: Boolean = this.recentError.isDefined

private def setErrorOnFutures(e: Throwable) {
this.recentError = true

log.error("Error on connection", e)

if (!this.connectionFuture.isCompleted) {
this.connectionFuture.failure(e)
this.disconnect
}

this.portalSuspended = false
this.currentPreparedStatement.map(p => this.parsedStatements.remove(p.query))
this.currentPreparedStatement = None
this.failQueryPromise(e)
this.failQuery(e)
}

override def onReadyForQuery() {
this.connectionFuture.trySuccess(this)

this.recentError = false
queryResult.foreach(this.succeedQueryPromise)
if (recentError.isDefined) {
setErrorOnFutures(recentError.get)
this.recentError = None
} else if (portalSuspended) {
portalSuspended = false
resultProcessor.get.portalSuspended()
} else {
this.clearResultProcessor.complete()
}
}

override def onError(m: ErrorMessage) {
Expand All @@ -171,12 +170,12 @@ class PostgreSQLConnection
val error = new GenericDatabaseException(m)
error.fillInStackTrace()

this.setErrorOnFutures(error)
recentError = Some(error)
}

override def onCommandComplete(m: CommandCompleteMessage) {
resultProcessor.get.completeCommand(m.rowsAffected, m.statusMessage)
this.currentPreparedStatement = None
queryResult = Some(new QueryResult(m.rowsAffected, m.statusMessage, this.currentQuery))
}

override def onParameterStatus(m: ParameterStatusMessage) {
Expand All @@ -187,24 +186,26 @@ class PostgreSQLConnection
}

override def onDataRow(m: DataRowMessage) {
val items = new Array[Any](m.values.size)
val items = new Array[Any](m.values.length)
var x = 0

while ( x < m.values.size ) {
val processor: ResultProcessor = this.resultProcessor.get
val columnsData = processor.columnTypes
while ( x < m.values.length ) {
items(x) = if ( m.values(x) == null ) {
null
} else {
val columnType = this.currentQuery.get.columnTypes(x)
val columnType = columnsData(x)
this.decoderRegistry.decode(columnType, m.values(x), configuration.charset)
}
x += 1
}

this.currentQuery.get.addRow(items)
processor.processRow(items)
}

override def onRowDescription(m: RowDescriptionMessage) {
this.currentQuery = Option(new MutableResultSet(m.columnDatas))
this.resultProcessor.get.columnTypes = m.columnDatas
this.setColumnDatas(m.columnDatas)
}

Expand All @@ -215,20 +216,15 @@ class PostgreSQLConnection
}

override def onAuthenticationResponse(message: AuthenticationMessage) {

message match {
case m: AuthenticationOkMessage => {
case m: AuthenticationOkMessage =>
log.debug("Successfully logged in to database")
this.authenticated = true
}
case m: AuthenticationChallengeCleartextMessage => {
case m: AuthenticationChallengeCleartextMessage =>
write(this.credential(m))
}
case m: AuthenticationChallengeMD5 => {
case m: AuthenticationChallengeMD5 =>
write(this.credential(m))
}
}

}

override def onNotificationResponse( message : NotificationResponse ) {
Expand Down Expand Up @@ -275,8 +271,8 @@ class PostgreSQLConnection
}

def validateIfItIsReadyForQuery(errorMessage: String) =
if (this.queryPromise.isDefined)
notReadyForQueryError(errorMessage, false)
if (this.resultProcessor.isDefined)
notReadyForQueryError(errorMessage, race = false)

private def validateQuery(query: String) {
this.validateIfItIsReadyForQuery("Can't run query because there is one query pending already")
Expand All @@ -286,30 +282,17 @@ class PostgreSQLConnection
}
}

private def queryPromise: Option[Promise[QueryResult]] = queryPromiseReference.get()

private def setQueryPromise(promise: Promise[QueryResult]) {
if (!this.queryPromiseReference.compareAndSet(None, Some(promise)))
notReadyForQueryError("Can't run query due to a race with another started query", true)
}

private def clearQueryPromise : Option[Promise[QueryResult]] = {
this.queryPromiseReference.getAndSet(None)
private def setResultProcessor(resultProcessor: ResultProcessor) {
if (!this.resultProcessorReference.compareAndSet(None, Some(resultProcessor)))
notReadyForQueryError("Can't run query due to a race with another started query", race = true)
}

private def failQueryPromise(t: Throwable) {
this.clearQueryPromise.foreach { promise =>
log.error("Setting error on future {}", promise)
promise.failure(t)
}
private def clearResultProcessor : ResultProcessor = {
this.resultProcessorReference.getAndSet(None).get
}

private def succeedQueryPromise(result: QueryResult) {
this.queryResult = None
this.currentQuery = None
this.clearQueryPromise.foreach {
_.success(result)
}
private def failQuery(t: Throwable) {
this.clearResultProcessor.failure(t)
}

private def write( message : ClientMessage ) {
Expand All @@ -319,4 +302,64 @@ class PostgreSQLConnection
override def toString: String = {
s"${this.getClass.getSimpleName}{counter=${this.currentCount}}"
}

private sealed trait ResultProcessor {
var columnTypes: IndexedSeq[PostgreSQLColumnData]

def processRow(items: Array[Any])

def complete()

def portalSuspended()

def failure(t: Throwable)

def completeCommand(rowsAffected: Int, statusMessage: String)
}

private class ConnectResultProcessor extends ResultProcessor{
val connectionPromise = Promise[Connection]()
override var columnTypes: IndexedSeq[PostgreSQLColumnData] = _

override def processRow(items: Array[Any]): Unit = {}

override def completeCommand(rowsAffected: Int, statusMessage: String): Unit = {}

override def portalSuspended(): Unit = {}

override def failure(t: Throwable): Unit = {
connectionPromise.failure(t)
}

override def complete(): Unit = {
connectionPromise.success(PostgreSQLConnection.this)
}
}

private class PromiseResultProcessor(val queryPromise : Promise[QueryResult]) extends ResultProcessor {
var currentQuery: Option[MutableResultSet[PostgreSQLColumnData]] = None
var queryResult: Option[QueryResult] = None

def columnTypes : IndexedSeq[PostgreSQLColumnData] = currentQuery.get.columnTypes
def columnTypes_= (columnData: IndexedSeq[PostgreSQLColumnData]) : Unit = {
currentQuery = Some(new MutableResultSet(columnData))
}
override def complete(): Unit = {
queryResult.foreach(queryPromise.success)
}

override def processRow(items: Array[Any]): Unit = {
currentQuery.get.addRow(items)
}

override def failure(t: Throwable): Unit = {
queryPromise.failure(t)
}

override def completeCommand(rowsAffected: Int, statusMessage: String): Unit = {
queryResult = Some(new QueryResult(rowsAffected, statusMessage, currentQuery))
}

override def portalSuspended(): Unit = {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import java.util.UUID
import com.github.mauricio.async.db.pool.{ConnectionPool, PoolConfiguration}
import com.github.mauricio.async.db.postgresql.exceptions.GenericDatabaseException
import com.github.mauricio.async.db.postgresql.{PostgreSQLConnection, DatabaseTestHelper}
import org.specs2.execute.{Result, Success, AsResult}
import org.specs2.mutable.Specification
import org.specs2.specification.Fixture

object ConnectionPoolSpec {
val Insert = "insert into transaction_test (id) values (?)"
Expand Down Expand Up @@ -60,7 +62,7 @@ class ConnectionPoolSpec extends Specification with DatabaseTestHelper {
}
}

"runs commands for a transaction in a single connection" in {
"runs commands for a transaction in a single connection" ! attempts {_ =>

val id = UUID.randomUUID().toString

Expand All @@ -85,6 +87,14 @@ class ConnectionPoolSpec extends Specification with DatabaseTestHelper {

}

val attemptsCount = 20
val attempts = new Fixture[Int] {
def apply[R : AsResult](f: Int => R) = {
(0 to attemptsCount).foldLeft(Success(): Result) { (res, i) =>
res and AsResult(f(i))
}
}
}
def withPool[R]( fn : (ConnectionPool[PostgreSQLConnection]) => R ) : R = {

val pool = new ConnectionPool( new PostgreSQLConnectionFactory(defaultConfiguration), PoolConfiguration.Default )
Expand Down