forked from swiftlang/swift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
attr_tensorflow_graph_sema.swift
65 lines (48 loc) · 2.4 KB
/
attr_tensorflow_graph_sema.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// RUN: %target-typecheck-verify-swift
import TensorFlow
// Accepted case: the attribute on a top-level function whose single parameter
// and result are both TensorHandle values (with different scalar types).
@TensorFlowGraph
func justHandles(_ x: TensorHandle<Int32>) -> TensorHandle<Double> {} // okay
// Accepted case: an overload of justHandles that takes a ResourceHandle and
// returns a VariantHandle.
@TensorFlowGraph
func justHandles(_ x: ResourceHandle) -> VariantHandle {} // okay
// Accepted case: two handle parameters and a tuple result made up of two
// TensorHandle values.
@TensorFlowGraph
func multiRetHandles(_ x: ResourceHandle, _ y: VariantHandle) -> (TensorHandle<Int32>, TensorHandle<Float>) {} // okay
// Accepted case: a Tensor parameter combined with a handle result.
@TensorFlowGraph
func tensorIn(_ x: Tensor<Float>) -> ResourceHandle {} // okay
// Accepted case: a handle parameter combined with a Tensor result.
@TensorFlowGraph
func tensorOut(_ x: ResourceHandle) -> Tensor<Float> {} // okay
// Accepted case: Tensor parameter and Tensor result. This function is also the
// value used by the function-type conversion checks later in this file.
@TensorFlowGraph
func tensors(_ x: Tensor<Float>) -> Tensor<Int32> {} // okay
// Accepted case: a Tensor parameter and a tuple result made up of two Tensors.
@TensorFlowGraph
func multiRetTensors(_ x: Tensor<Int32>) -> (Tensor<Double>, Tensor<Float>) {} // okay
// Rejected case: the function declares a generic parameter T. The diagnostic is
// anchored on the attribute line, so nothing may be inserted between the
// directive below and the attribute.
// expected-error @+1 {{@TensorFlowGraph cannot be applied to generic functions}}
@TensorFlowGraph
func generic<T>(_ x: Tensor<T>) -> Tensor<T> {}
// Rejected case: the parameter is a plain Int32. The ResourceHandle result is
// accepted by other declarations in this file, so the parameter type is what
// differs from the accepted cases.
// expected-error @+1 {{@TensorFlowGraph can only be applied to functions whose parameters and return values are TensorFlow values or aggregates of TensorFlow values}}
@TensorFlowGraph
func wrongInput(_ x: Int32) -> ResourceHandle {}
// Rejected case: the result is a plain Float. The Tensor<Float> parameter is
// accepted by other declarations in this file, so the result type is what
// differs from the accepted cases.
// expected-error @+1 {{@TensorFlowGraph can only be applied to functions whose parameters and return values are TensorFlow values or aggregates of TensorFlow values}}
@TensorFlowGraph
func wrongOutput(_ x: Tensor<Float>) -> Float {}
// Container type used to check that the attribute is rejected on members, even
// when their signatures use only Tensor types.
enum SomeType {
// Rejected case: instance method.
// expected-error @+1 {{@TensorFlowGraph can only be applied to top-level functions}}
@TensorFlowGraph
func methodNotOkay(_ x: Tensor<Float>) -> Tensor<Float> {}
// Rejected case: static method; the same top-level-only diagnostic applies.
// expected-error @+1 {{@TensorFlowGraph can only be applied to top-level functions}}
@TensorFlowGraph
static func methodNotOkay2(_ x: Tensor<Float>) -> Tensor<Float> {}
}
// Conversion checks for a reference to the attributed function tensors(_:):
// binding it to a @convention(tensorflow) function type is accepted, while
// binding it to a function type without that convention is diagnosed.
let f: @convention(tensorflow) (Tensor<Float>) -> Tensor<Int32> = tensors(_:) // okay
let g: (Tensor<Float>) -> Tensor<Int32> = tensors(_:) // expected-error {{TensorFlow functions cannot be converted to other function types}}
// Higher-order function whose parameter has a @convention(tensorflow) function
// type; passing the attributed function tensors as the argument is accepted.
func hof(_ f: @convention(tensorflow) (Tensor<Float>) -> Tensor<Int32>) -> Tensor<Float> {}
_ = hof(tensors) // okay
// Enable these tests when SR-8487 is fixed.
//
// let closure: @convention(tensorflow) (Tensor<Float>) -> Tensor<Int32> = {
// return Tensor<Int32>($0)
// } // okay
//
// hof {
// return Tensor<Int32>($0)
// } // okay
// Higher-order function whose parameter is a function type without the
// tensorflow convention; passing the attributed function tensors is diagnosed
// as an invalid conversion.
func hofHost(_ f: (Tensor<Float>) -> Tensor<Int32>) -> Tensor<Float> {}
_ = hofHost(tensors) // expected-error {{TensorFlow functions cannot be converted to other function types}}