Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nsc committed Jul 9, 2023
1 parent 94e4a9e commit ecffd05
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 292 deletions.
52 changes: 15 additions & 37 deletions SwiftTLS/Sources/Log.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,51 +21,24 @@ class LoggingDateFormatter : DateFormatter
}
}

private var threadNumber = 0
private var threadNumberDict: [Thread:Int] = [:]
private let threadNumberQueue = DispatchQueue(label: "threadNumber")

extension Thread {
var number: Int {
var number = 0
let thread = self
threadNumberQueue.sync {
if let n = threadNumberDict[thread] {
number = n
return
}

threadNumber += 1

threadNumberDict[thread] = threadNumber

number = threadNumber
}

return number
}

func removeThreadNumber() {
let thread = self
threadNumberQueue.async {
threadNumberDict.removeValue(forKey: thread)
public class Log
{
@TaskLocal static var connectionNumber: Int?
public static func withConnectionNumber<R>(_ n: Int, _ handler: () async throws -> R) async rethrows -> R {
try await $connectionNumber.withValue(n) {
try await handler()
}
}
}

class Log
{
var enabled: Bool = true
private let formatter = LoggingDateFormatter()
fileprivate let formatter = LoggingDateFormatter()
private let logFile: FileHandle = FileHandle(fileDescriptor: 1)
private let logQueue = DispatchQueue(label: "org.swifttls.logging")

func log(_ message: @autoclosure () -> String, file: StaticString, line: UInt, time: Date) {
func log(_ message: @autoclosure () -> String, file: StaticString, line: UInt, prefixString: String = "") {
if enabled {
let threadNumber = Thread.current.number

logQueue.sync {
let line = "\(formatter.string(from: time)) (~\(threadNumber)): \(message())\n"
let line = "\(prefixString)\(message())\n"
let utf8 = Data(line.utf8)

logFile.write(utf8)
Expand All @@ -76,7 +49,12 @@ class Log

private let logger = Log()
public func log(_ message: @autoclosure () -> String, file: StaticString = #file, line: UInt = #line) {
logger.log(message(), file: file, line: line, time: Date())
var prefixString = "\(logger.formatter.string(from: Date())) "
if let n = Log.connectionNumber {
prefixString += "~\(n): "
}

logger.log(message(), file: file, line: line, prefixString: prefixString)
}

public func TLSEnableLogging(_ v: Bool) {
Expand Down
24 changes: 13 additions & 11 deletions SwiftTLS/Sources/TLS/TLSClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,21 @@ extension TLSClient : ClientSocketProtocol

public func connect(hostname: String, port: UInt16 = 443) async throws
{
if let address = IPv4Address.addressWithString(hostname, port: port) {
var hostNameAndPort = hostname
if port != 443 {
hostNameAndPort = "\(hostname):\(port)"
await ConnectionNumber.increase()
try await Log.withConnectionNumber(ConnectionNumber.value) {
if let address = IPv4Address.addressWithString(hostname, port: port) {
var hostNameAndPort = hostname
if port != 443 {
hostNameAndPort = "\(hostname):\(port)"
}
self.serverNames = [hostNameAndPort]

try await connect(address)
}
else {
throw TLSError.error("Error: Could not resolve host \(hostname)")
}
self.serverNames = [hostNameAndPort]

try await connect(address)
}
else {
throw TLSError.error("Error: Could not resolve host \(hostname)")
}

}

public enum EarlyDataState {
Expand Down
93 changes: 59 additions & 34 deletions SwiftTLS/Sources/TLS/TLSServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,20 @@ extension TLSServer : ServerSocketProtocol
{
let clientSocket = try await self.serverSocket.acceptConnection() as! TCPSocket
let clientTLSSocket = TLSServer(configuration: self.configuration, context: self.context)

try await BigInt.withContext { _ in
clientTLSSocket.socket = clientSocket
clientTLSSocket.signer = self.signer
clientTLSSocket.configuration = self.configuration
clientTLSSocket.recordLayer.dataProvider = clientSocket
clientTLSSocket.context = self.context

clientTLSSocket.earlyDataResponseHandler = earlyDataResponseHandler

try await clientTLSSocket._acceptConnection()

await ConnectionNumber.increase()
try await Log.withConnectionNumber(ConnectionNumber.value) {
try await BigInt.withContext { _ in
clientTLSSocket.socket = clientSocket
clientTLSSocket.signer = self.signer
clientTLSSocket.configuration = self.configuration
clientTLSSocket.recordLayer.dataProvider = clientSocket
clientTLSSocket.context = self.context

clientTLSSocket.earlyDataResponseHandler = earlyDataResponseHandler

try await clientTLSSocket._acceptConnection()
}
}

return clientTLSSocket
Expand All @@ -137,37 +140,59 @@ extension TLSServer : ServerSocketProtocol
case error(Error)
case client(TLSConnection)
}

public func acceptConnection(withEarlyDataResponseHandler earlyDataResponseHandler: EarlyDataResponseHandler?, completionHandler: @escaping (AcceptConnectionResult) async -> ()) async throws
{
let clientSocket = try await serverSocket.acceptConnection() as! TCPSocket

Task {
try await BigInt.withContext { _ in
if let address = clientSocket.peerName {
log("Connection from \(address)")
}

let clientTLSSocket = TLSServer(configuration: self.configuration, context: self.context)
clientTLSSocket.socket = clientSocket
clientTLSSocket.signer = self.signer
clientTLSSocket.configuration = self.configuration
clientTLSSocket.recordLayer.dataProvider = clientSocket
clientTLSSocket.context = self.context

clientTLSSocket.earlyDataResponseHandler = earlyDataResponseHandler

do {
try await clientTLSSocket._acceptConnection()
} catch let error {
await completionHandler(.error(error))
await ConnectionNumber.increase()
try await Log.withConnectionNumber(ConnectionNumber.value) {
try await BigInt.withContext { _ in
if let address = clientSocket.peerName {
log("Connection from \(address)")
}

let clientTLSSocket = TLSServer(configuration: self.configuration, context: self.context)
clientTLSSocket.socket = clientSocket
clientTLSSocket.signer = self.signer
clientTLSSocket.configuration = self.configuration
clientTLSSocket.recordLayer.dataProvider = clientSocket
clientTLSSocket.context = self.context

clientTLSSocket.earlyDataResponseHandler = earlyDataResponseHandler

do {
try await clientTLSSocket._acceptConnection()
} catch let error {
await completionHandler(.error(error))
}

await completionHandler(.client(clientTLSSocket))
}

await completionHandler(.client(clientTLSSocket))

Thread.current.removeThreadNumber()
}
}
}

}

actor ConnectionNumber {
func increase() {
_value += 1
}

var value: Int { _value }
private var _value = 0

static let shared = ConnectionNumber()

static func increase() async {
await shared.increase()
}

static var value: Int {
get async {
await shared.value
}
}
}
59 changes: 32 additions & 27 deletions SwiftTLSTool/client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,50 @@ func connectTo(host : String, port : UInt16 = 443, supportedVersions: [TLSProtoc

// Connect twice to test session resumption
var context: TLSClientContext? = nil
var connectionNumber = 0
try await BigInt.withContext { _ in
var client: TLSClient
for _ in 0..<2 {
do {
print("Connecting to \(host):\(port)")
client = TLSClient(configuration: configuration, context: context)

let requestData = [UInt8]("GET / HTTP/1.1\r\nHost: \(host)\r\nUser-Agent: SwiftTLS\r\nConnection: Close\r\n\r\n".utf8)
try await client.connect(hostname: host, port: port, withEarlyData: Data(requestData))

let earlyDataState = client.earlyDataState
print("Early data: \(earlyDataState)")

if context == nil {
context = client.context as? TLSClientContext
}

print("Connection established using cipher suite \(client.cipherSuite!)")

if earlyDataState != .accepted {
try await client.write(requestData)
}

while true {
let data = try await client.read(count: 4096)
if data.count == 0 {
break

connectionNumber += 1
try await Log.withConnectionNumber(connectionNumber) {

let requestData = [UInt8]("GET / HTTP/1.1\r\nHost: \(host)\r\nUser-Agent: SwiftTLS\r\nConnection: Close\r\n\r\n".utf8)
try await client.connect(hostname: host, port: port, withEarlyData: Data(requestData))

let earlyDataState = client.earlyDataState
print("Early data: \(earlyDataState)")

if context == nil {
context = client.context as? TLSClientContext
}

_ = data.withUnsafeBytes { buffer in
write(1, buffer.baseAddress, buffer.count)

print("Connection established using cipher suite \(client.cipherSuite!)")

if earlyDataState != .accepted {
try await client.write(requestData)
}

while true {
let data = try await client.read(count: 4096)
if data.count == 0 {
break
}

_ = data.withUnsafeBytes { buffer in
write(1, buffer.baseAddress, buffer.count)
}

break
}

break
}
}
catch (let error) {
await client.close()

print("Error: \(error)")
}
}
Expand Down
7 changes: 0 additions & 7 deletions Tests/LinuxMain.swift

This file was deleted.

7 changes: 4 additions & 3 deletions Tests/SwiftTLSTests/SocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ class SocketTests: XCTestCase {
let client = TCPSocket()
do {
try await client.connect(address)

await self.waitForExpectations(timeout: 50.0, handler: { (error : Error?) -> Void in

// await self.fulfillment(of: [expectation], timeout: 50)
await waitForExpectations(timeout: 50.0, handler: { (error : Error?) -> Void in
})
}
}
catch let error as SocketError {
print("\(error)")
XCTFail()
Expand Down
4 changes: 1 addition & 3 deletions Tests/SwiftTLSTests/TLSTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,8 @@ class TLSTests: XCTestCase {
}

try await serverSideClientSocket?.write([4,5,6])

await serverSideClientSocket?.close()
}

await server.close()
} catch(let error) {
XCTFail("\(error)")
Expand Down
Loading

0 comments on commit ecffd05

Please sign in to comment.