From fc5e3f61ecb53d4bf2f3a64eac3a0b41b6820596 Mon Sep 17 00:00:00 2001 From: Anton Zherdev Date: Mon, 10 Aug 2015 21:07:49 +1200 Subject: [PATCH] PostgreSQLConnection refactoring --- .../db/postgresql/PostgreSQLConnection.scala | 201 +++++++++++------- .../postgresql/pool/ConnectionPoolSpec.scala | 12 +- 2 files changed, 133 insertions(+), 80 deletions(-) diff --git a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/PostgreSQLConnection.scala b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/PostgreSQLConnection.scala index 8c58076b..6b5766a1 100644 --- a/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/PostgreSQLConnection.scala +++ b/postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/PostgreSQLConnection.scala @@ -16,23 +16,23 @@ package com.github.mauricio.async.db.postgresql -import com.github.mauricio.async.db.QueryResult -import com.github.mauricio.async.db.column.{ColumnEncoderRegistry, ColumnDecoderRegistry} -import com.github.mauricio.async.db.exceptions.{InsufficientParametersException, ConnectionStillRunningQueryException} +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} + +import com.github.mauricio.async.db.column.{ColumnDecoderRegistry, ColumnEncoderRegistry} +import com.github.mauricio.async.db.exceptions.{ConnectionStillRunningQueryException, InsufficientParametersException} import com.github.mauricio.async.db.general.MutableResultSet import com.github.mauricio.async.db.pool.TimeoutScheduler import com.github.mauricio.async.db.postgresql.codec.{PostgreSQLConnectionDelegate, PostgreSQLConnectionHandler} import com.github.mauricio.async.db.postgresql.column.{PostgreSQLColumnDecoderRegistry, PostgreSQLColumnEncoderRegistry} import com.github.mauricio.async.db.postgresql.exceptions._ +import com.github.mauricio.async.db.postgresql.messages.backend._ +import com.github.mauricio.async.db.postgresql.messages.frontend._ import com.github.mauricio.async.db.util._ -import com.github.mauricio.async.db.{Configuration, Connection} -import java.util.concurrent.atomic.{AtomicLong,AtomicInteger,AtomicReference} -import messages.backend._ -import messages.frontend._ -import scala.Some -import scala.concurrent._ +import com.github.mauricio.async.db.{Configuration, Connection, QueryResult} import io.netty.channel.EventLoopGroup -import java.util.concurrent.CopyOnWriteArrayList + +import scala.concurrent._ object PostgreSQLConnection { final val Counter = new AtomicLong() @@ -70,26 +70,27 @@ class PostgreSQLConnection private val parsedStatements = new scala.collection.mutable.HashMap[String, PreparedStatementHolder]() private var authenticated = false - private val connectionFuture = Promise[Connection]() - - private var recentError = false - private val queryPromiseReference = new AtomicReference[Option[Promise[QueryResult]]](None) - private var currentQuery: Option[MutableResultSet[PostgreSQLColumnData]] = None + private var recentError : Option[Throwable] = None + private val resultProcessorReference = new AtomicReference[Option[ResultProcessor]](None) private var currentPreparedStatement: Option[PreparedStatementHolder] = None private var version = Version(0,0,0) - private var notifyListeners = new CopyOnWriteArrayList[NotificationResponse => Unit]() - - private var queryResult: Option[QueryResult] = None + private val notifyListeners = new CopyOnWriteArrayList[NotificationResponse => Unit]() + private var portalSuspended = false override def eventLoopGroup : EventLoopGroup = group - def isReadyForQuery: Boolean = this.queryPromise.isEmpty + + private def resultProcessor: Option[ResultProcessor] = resultProcessorReference.get() + + def isReadyForQuery: Boolean = this.resultProcessor.isEmpty def connect: Future[Connection] = { + val processor = new ConnectResultProcessor() + setResultProcessor(processor) this.connectionHandler.connect.onFailure { - case e => this.connectionFuture.tryFailure(e) + case e => processor.failure(e) } - this.connectionFuture.future + processor.connectionPromise.future } override def disconnect: Future[Connection] = this.connectionHandler.disconnect.map( c => this ) @@ -103,7 +104,7 @@ class PostgreSQLConnection validateQuery(query) val promise = Promise[QueryResult]() - this.setQueryPromise(promise) + this.setResultProcessor(new PromiseResultProcessor(promise)) write(new QueryMessage(query)) addTimeout(promise,configuration.queryTimeout) @@ -114,18 +115,19 @@ class PostgreSQLConnection validateQuery(query) val promise = Promise[QueryResult]() - this.setQueryPromise(promise) + val processor: PromiseResultProcessor = new PromiseResultProcessor(promise) + this.setResultProcessor(processor) val holder = this.parsedStatements.getOrElseUpdate(query, new PreparedStatementHolder( query, preparedStatementsCounter.incrementAndGet )) if (holder.paramsCount != values.length) { - this.clearQueryPromise + this.clearResultProcessor throw new InsufficientParametersException(holder.paramsCount, values) } this.currentPreparedStatement = Some(holder) - this.currentQuery = Some(new MutableResultSet(holder.columnDatas)) + processor.columnTypes = holder.columnDatas write( if (holder.prepared) new PreparedStatementExecuteMessage(holder.statementId, holder.realQuery, values, this.encoderRegistry) @@ -137,32 +139,29 @@ class PostgreSQLConnection promise.future } - override def onError( exception : Throwable ) { - this.setErrorOnFutures(exception) + override def onError(exception : Throwable) { + recentError = Some(exception) } - def hasRecentError: Boolean = this.recentError + def hasRecentError: Boolean = this.recentError.isDefined private def setErrorOnFutures(e: Throwable) { - this.recentError = true - - log.error("Error on connection", e) - - if (!this.connectionFuture.isCompleted) { - this.connectionFuture.failure(e) - this.disconnect - } - + this.portalSuspended = false this.currentPreparedStatement.map(p => this.parsedStatements.remove(p.query)) this.currentPreparedStatement = None - this.failQueryPromise(e) + this.failQuery(e) } override def onReadyForQuery() { - this.connectionFuture.trySuccess(this) - - this.recentError = false - queryResult.foreach(this.succeedQueryPromise) + if (recentError.isDefined) { + setErrorOnFutures(recentError.get) + this.recentError = None + } else if (portalSuspended) { + portalSuspended = false + resultProcessor.get.portalSuspended() + } else { + this.clearResultProcessor.complete() + } } override def onError(m: ErrorMessage) { @@ -171,12 +170,12 @@ class PostgreSQLConnection val error = new GenericDatabaseException(m) error.fillInStackTrace() - this.setErrorOnFutures(error) + recentError = Some(error) } override def onCommandComplete(m: CommandCompleteMessage) { + resultProcessor.get.completeCommand(m.rowsAffected, m.statusMessage) this.currentPreparedStatement = None - queryResult = Some(new QueryResult(m.rowsAffected, m.statusMessage, this.currentQuery)) } override def onParameterStatus(m: ParameterStatusMessage) { @@ -187,24 +186,26 @@ class PostgreSQLConnection } override def onDataRow(m: DataRowMessage) { - val items = new Array[Any](m.values.size) + val items = new Array[Any](m.values.length) var x = 0 - while ( x < m.values.size ) { + val processor: ResultProcessor = this.resultProcessor.get + val columnsData = processor.columnTypes + while ( x < m.values.length ) { items(x) = if ( m.values(x) == null ) { null } else { - val columnType = this.currentQuery.get.columnTypes(x) + val columnType = columnsData(x) this.decoderRegistry.decode(columnType, m.values(x), configuration.charset) } x += 1 } - this.currentQuery.get.addRow(items) + processor.processRow(items) } override def onRowDescription(m: RowDescriptionMessage) { - this.currentQuery = Option(new MutableResultSet(m.columnDatas)) + this.resultProcessor.get.columnTypes = m.columnDatas this.setColumnDatas(m.columnDatas) } @@ -215,20 +216,15 @@ class PostgreSQLConnection } override def onAuthenticationResponse(message: AuthenticationMessage) { - message match { - case m: AuthenticationOkMessage => { + case m: AuthenticationOkMessage => log.debug("Successfully logged in to database") this.authenticated = true - } - case m: AuthenticationChallengeCleartextMessage => { + case m: AuthenticationChallengeCleartextMessage => write(this.credential(m)) - } - case m: AuthenticationChallengeMD5 => { + case m: AuthenticationChallengeMD5 => write(this.credential(m)) - } } - } override def onNotificationResponse( message : NotificationResponse ) { @@ -275,8 +271,8 @@ class PostgreSQLConnection } def validateIfItIsReadyForQuery(errorMessage: String) = - if (this.queryPromise.isDefined) - notReadyForQueryError(errorMessage, false) + if (this.resultProcessor.isDefined) + notReadyForQueryError(errorMessage, race = false) private def validateQuery(query: String) { this.validateIfItIsReadyForQuery("Can't run query because there is one query pending already") @@ -286,30 +282,17 @@ class PostgreSQLConnection } } - private def queryPromise: Option[Promise[QueryResult]] = queryPromiseReference.get() - - private def setQueryPromise(promise: Promise[QueryResult]) { - if (!this.queryPromiseReference.compareAndSet(None, Some(promise))) - notReadyForQueryError("Can't run query due to a race with another started query", true) - } - - private def clearQueryPromise : Option[Promise[QueryResult]] = { - this.queryPromiseReference.getAndSet(None) + private def setResultProcessor(resultProcessor: ResultProcessor) { + if (!this.resultProcessorReference.compareAndSet(None, Some(resultProcessor))) + notReadyForQueryError("Can't run query due to a race with another started query", race = true) } - private def failQueryPromise(t: Throwable) { - this.clearQueryPromise.foreach { promise => - log.error("Setting error on future {}", promise) - promise.failure(t) - } + private def clearResultProcessor : ResultProcessor = { + this.resultProcessorReference.getAndSet(None).get } - private def succeedQueryPromise(result: QueryResult) { - this.queryResult = None - this.currentQuery = None - this.clearQueryPromise.foreach { - _.success(result) - } + private def failQuery(t: Throwable) { + this.clearResultProcessor.failure(t) } private def write( message : ClientMessage ) { @@ -319,4 +302,64 @@ class PostgreSQLConnection override def toString: String = { s"${this.getClass.getSimpleName}{counter=${this.currentCount}}" } + + private sealed trait ResultProcessor { + var columnTypes: IndexedSeq[PostgreSQLColumnData] + + def processRow(items: Array[Any]) + + def complete() + + def portalSuspended() + + def failure(t: Throwable) + + def completeCommand(rowsAffected: Int, statusMessage: String) + } + + private class ConnectResultProcessor extends ResultProcessor{ + val connectionPromise = Promise[Connection]() + override var columnTypes: IndexedSeq[PostgreSQLColumnData] = _ + + override def processRow(items: Array[Any]): Unit = {} + + override def completeCommand(rowsAffected: Int, statusMessage: String): Unit = {} + + override def portalSuspended(): Unit = {} + + override def failure(t: Throwable): Unit = { + connectionPromise.failure(t) + } + + override def complete(): Unit = { + connectionPromise.success(PostgreSQLConnection.this) + } + } + + private class PromiseResultProcessor(val queryPromise : Promise[QueryResult]) extends ResultProcessor { + var currentQuery: Option[MutableResultSet[PostgreSQLColumnData]] = None + var queryResult: Option[QueryResult] = None + + def columnTypes : IndexedSeq[PostgreSQLColumnData] = currentQuery.get.columnTypes + def columnTypes_= (columnData: IndexedSeq[PostgreSQLColumnData]) : Unit = { + currentQuery = Some(new MutableResultSet(columnData)) + } + override def complete(): Unit = { + queryResult.foreach(queryPromise.success) + } + + override def processRow(items: Array[Any]): Unit = { + currentQuery.get.addRow(items) + } + + override def failure(t: Throwable): Unit = { + queryPromise.failure(t) + } + + override def completeCommand(rowsAffected: Int, statusMessage: String): Unit = { + queryResult = Some(new QueryResult(rowsAffected, statusMessage, currentQuery)) + } + + override def portalSuspended(): Unit = {} + } } diff --git a/postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/pool/ConnectionPoolSpec.scala b/postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/pool/ConnectionPoolSpec.scala index b71ebe65..06b414b6 100644 --- a/postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/pool/ConnectionPoolSpec.scala +++ b/postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/pool/ConnectionPoolSpec.scala @@ -21,7 +21,9 @@ import java.util.UUID import com.github.mauricio.async.db.pool.{ConnectionPool, PoolConfiguration} import com.github.mauricio.async.db.postgresql.exceptions.GenericDatabaseException import com.github.mauricio.async.db.postgresql.{PostgreSQLConnection, DatabaseTestHelper} +import org.specs2.execute.{Result, Success, AsResult} import org.specs2.mutable.Specification +import org.specs2.specification.Fixture object ConnectionPoolSpec { val Insert = "insert into transaction_test (id) values (?)" @@ -60,7 +62,7 @@ class ConnectionPoolSpec extends Specification with DatabaseTestHelper { } } - "runs commands for a transaction in a single connection" in { + "runs commands for a transaction in a single connection" ! attempts {_ => val id = UUID.randomUUID().toString @@ -85,6 +87,14 @@ class ConnectionPoolSpec extends Specification with DatabaseTestHelper { } + val attemptsCount = 20 + val attempts = new Fixture[Int] { + def apply[R : AsResult](f: Int => R) = { + (0 to attemptsCount).foldLeft(Success(): Result) { (res, i) => + res and AsResult(f(i)) + } + } + } def withPool[R]( fn : (ConnectionPool[PostgreSQLConnection]) => R ) : R = { val pool = new ConnectionPool( new PostgreSQLConnectionFactory(defaultConfiguration), PoolConfiguration.Default )