forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
schema_matching.cpp
771 lines (709 loc) · 26.3 KB
/
schema_matching.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/frontend/builtin_functions.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/runtime/operator.h>
namespace torch::jit {
/// Strips one level of Optional from `opt_type`.
///
/// A DynamicType is first replaced by its fallback type (recursively). If the
/// resulting type is Optional[T], T is returned; any other type is returned
/// unchanged.
static inline TypePtr unwrapOptional(TypePtr opt_type) {
  if (const auto* as_dynamic = opt_type->castRaw<c10::DynamicType>()) {
    return unwrapOptional(as_dynamic->fallback());
  }
  const auto as_optional = opt_type->cast<OptionalType>();
  if (!as_optional) {
    return opt_type;
  }
  return as_optional->getElementType();
}
static inline bool isIntOrFloatUsedAsList(
const Value* value,
const Argument& arg) {
// Look for int[N] or float[N]
const auto& v_type = value->type();
if (v_type != FloatType::get() && v_type != IntType::get())
return false;
auto arg_type = unwrapOptional(arg.type());
auto list_type = arg_type->cast<ListType>();
return list_type && list_type->getElementType() == v_type && arg.N();
}
/// Returns true if `type` is a Tuple in which all the elements have the
/// same type or if it's a subtype of `list_type_`.
bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
auto list_type = list_type_->castRaw<ListType>();
if (!list_type) {
return false;
}
if (type->isSubtypeOf(*list_type_)) {
return true;
}
if (auto tuple = type->castRaw<TupleType>()) {
return std::all_of(
tuple->elements().begin(),
tuple->elements().end(),
[&](const TypePtr& t) {
// TODO: resolve VarType if necessary
return t->isSubtypeOf(*list_type->getElementType());
});
}
return false;
}
// Applies implicit conversion from value trying to turn it into type
// concrete_type. It succeeds if `return_value->isSubtypeOf(concrete_type)`.
//
// This function always returns a Value*, even when no conversion applied; the
// caller must still check the returned value's type against `concrete_type`.
// Any conversion is materialized by inserting new nodes into `graph`.
//
// Conversions are attempted in this order:
//  1. Optional[T] target: a value that is neither Optional-kinded nor a
//     subtype of None is converted as if the target were T. The result of
//     that recursive call is returned directly.
//  2. An unannotated empty list literal (prim::EmptyListLiteral) is replaced
//     by an empty list of the target list's element type.
//  3. Tuple values: converted to a list when `convertibleToList` allows it;
//     and, when the target is a tuple of the same arity that the value is not
//     already a subtype of, each element is converted recursively.
//  4. Only when `allow_conversions` is set: Tensor -> float/complex/int/
//     Number via the *Implicit ops, Number -> float/complex/int, and
//     str -> Device.
Value* tryConvertToType(
    const SourceRange& loc,
    Graph& graph,
    const TypePtr& concrete_type,
    Value* value,
    bool allow_conversions) {
  // treat conversion to Optional[T] as conversions to T
  if (OptionalTypePtr op = concrete_type->cast<OptionalType>()) {
    // Values that are already Optional, or are None, are left for the
    // ordinary subtype check performed by the caller.
    if (value->type()->kind() != OptionalType::Kind &&
        !value->type()->isSubtypeOf(*NoneType::get())) {
      return tryConvertToType(
          loc, graph, op->getElementType(), value, allow_conversions);
    }
  }
  // allow temporary, unannotated list literals `[]` to match to arbitrary list
  // types
  if (value->node()->kind() == prim::EmptyListLiteral &&
      concrete_type->cast<ListType>()) {
    value = graph
                .insertNode(graph.createList(
                    concrete_type->cast<ListType>()->getElementType(), {}))
                ->output();
  }
  if (auto value_tuple = value->type()->cast<TupleType>()) {
    // Allow tuples whose elements all fit the target list's element type to
    // be casted implicitly to lists of appropriate types (an Optional list
    // target is unwrapped first).
    if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
      auto unpacked = createTupleUnpack(value);
      auto elem_type =
          unwrapOptional(concrete_type)->expectRef<ListType>().getElementType();
      value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
    }
    // inductively apply implicit conversions to tuples
    // NOTE: `value_tuple` still describes the original tuple type here, even
    // if `value` was replaced by a list above.
    if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
      if (!value_tuple->isSubtypeOf(*concrete_tuple) &&
          concrete_tuple->elements().size() == value_tuple->elements().size()) {
        auto unpacked = createTupleUnpack(value);
        std::vector<Value*> converted;
        for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
          // Each element is converted independently; unconvertible elements
          // are passed through and the caller's subtype check will reject
          // the rebuilt tuple.
          converted.emplace_back(tryConvertToType(
              loc,
              graph,
              concrete_tuple->elements().at(i),
              unpacked.at(i),
              allow_conversions));
        }
        value = graph.insertNode(graph.createTuple(converted))->output();
      }
    }
  }
  // implicit conversions
  if (allow_conversions) {
    // Convert tensor or number to concrete int/float types
    bool value_isa_tensor = value->type()->isSubtypeOf(*TensorType::get());
    // Exact equality with Number, not a subtype check.
    bool value_equals_number = *value->type() == *NumberType::get();
    bool concrete_float = *concrete_type == *FloatType::get();
    bool concrete_complex = *concrete_type == *ComplexType::get();
    bool concrete_int = *concrete_type == *IntType::get();
    bool concrete_number = *concrete_type == *NumberType::get();
    if (value_isa_tensor) {
      if (concrete_float) {
        value = graph.insert(aten::FloatImplicit, {value}, {}, loc);
      } else if (concrete_complex) {
        value = graph.insert(aten::ComplexImplicit, {value}, {}, loc);
      } else if (concrete_int) {
        value = graph.insert(aten::IntImplicit, {value}, {}, loc);
      } else if (concrete_number) {
        value = graph.insert(aten::ScalarImplicit, {value}, {}, loc);
      }
    } else if (value_equals_number) {
      if (concrete_float) {
        value = graph.insert(aten::Float, {value}, {}, loc);
      } else if (concrete_complex) {
        value = graph.insert(aten::Complex, {value}, {}, loc);
      } else if (concrete_int) {
        value = graph.insert(aten::Int, {value}, {}, loc);
      }
    }
    // Convert strings to device
    if (value->type()->isSubtypeOf(*StringType::get()) &&
        concrete_type->isSubtypeOf(*DeviceObjType::get())) {
      return graph.insert(aten::device, {value}, {}, loc);
    }
  }
  return value;
}
// Checks if `named_value` can be used as a value for `arg`. If `arg` is a
// VarType, it will be added to the type_env through `matchTypeVariables` as
// the corresponding actual type. If `allow_conversions` is true, implicit
// conversions to the `arg` type may be performed through `tryConvertToType`.
//
// Returns the (possibly converted) value on success, or nullptr on failure.
// Diagnostics are written through `err()` only when `failure_messages` is
// non-null; `err` must not be invoked otherwise.
static Value* tryMatchArgument(
    const Argument& arg,
    Graph& graph,
    const SourceRange& loc,
    const NamedValue& named_value,
    std::ostream* failure_messages,
    const std::function<std::ostream&()>& err,
    bool allow_conversions,
    TypeEnv& type_env) {
  Value* value = named_value.value(graph);

  // Some functions that take lists of integers or floats for fixed size arrays
  // also allow single ints/floats to be passed in their place. The single
  // int/float is then repeated to the length of the list
  if (isIntOrFloatUsedAsList(value, arg)) {
    std::vector<Value*> repeated(*arg.N(), value);
    value =
        graph.insertNode(graph.createList(value->type(), repeated))->output();
  }

  // Resolve VarType variables
  const MatchTypeReturn matched =
      matchTypeVariables(arg.type(), value->type(), type_env);
  if (!matched.success()) {
    if (failure_messages) {
      err() << "Could not match type " << value->type()->repr_str() << " to "
            << arg.type()->repr_str() << " in argument '" << arg.name()
            << "': " << matched.reason() << ".\n";
    }
    return nullptr;
  }
  const auto concrete_type = tryEvalTypeVariables(arg.type(), type_env);
  if (!concrete_type) {
    if (failure_messages) {
      // Terminate the line like every other diagnostic emitted here.
      err() << "Type variables in type " << arg.type()->repr_str()
            << " could not be inferred from actual type "
            << value->type()->repr_str() << ".\n";
    }
    return nullptr;
  }

  // Check if the value can be matched to the arg through any implicit
  // conversions
  value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
  std::stringstream why_not;
  if (value->type()->isSubtypeOfExt(
          *concrete_type, /*why_not=*/(failure_messages) ? &why_not : nullptr)) {
    return value;
  }

  if (failure_messages) {
    auto& ostream = err()
        << arg.formatTypeMismatchMsg(value->type()->repr_str());
    if (auto pt = value->type()->cast<TensorType>()) {
      if (pt->isInferredType()) {
        ostream << "Inferred the value for argument '" << arg.name()
                << "' to be of type 'Tensor' "
                << "because it was not annotated with an explicit type.\n";
      }
    }
    if (auto v = value->type()->cast<ListType>()) {
      if (v->getElementType()->isSubtypeOf(*TensorType::get())) {
        // Correct annotation syntax is torch.jit.annotate(List[T], []).
        ostream << "Empty lists default to List[Tensor]. Add a variable "
                   "annotation to the assignment to create an empty list "
                   "of another type (torch.jit.annotate(List[T], []) where T "
                   "is the type of elements in the list for Python 2)\n";
      }
    }
    ostream << why_not.str();
  }
  return nullptr;
}
/// Returns the index of the first keyword argument in `kwargs` that supplies
/// the schema argument called `name`, or nullopt if none does. For aten
/// schemas, a kwarg named "input" is also accepted for an argument named
/// "self".
c10::optional<size_t> findInputWithName(
    const std::string& name,
    at::ArrayRef<NamedValue> kwargs,
    bool is_aten) {
  // TS doesn't understand that the self argument in function schemas is
  // renamed to input for the functional variant.
  const bool accept_input_alias = is_aten && name == "self";
  for (size_t idx = 0; idx < kwargs.size(); ++idx) {
    const auto& kwarg_name = kwargs[idx].name();
    if ((accept_input_alias && kwarg_name == "input") || kwarg_name == name) {
      return idx;
    }
  }
  return c10::nullopt;
}
/// Builds a list of `elem_type` out of `varargs`, matching every value against
/// a synthetic "<varargs>" argument of that element type (with implicit
/// conversions when `convert_tensor_to_num` is set). Returns the new list
/// value, or nullptr as soon as one value fails to match. This lets calls like
/// torch.zeros(1, 2, 3) match aten::zeros(int[]).
static Value* tryCreateList(
    const TypePtr& elem_type,
    Graph& graph,
    const SourceRange& loc,
    at::ArrayRef<NamedValue> varargs,
    std::ostream* failure_messages,
    const std::function<std::ostream&()>& err,
    bool convert_tensor_to_num,
    TypeEnv& type_env) {
  const Argument element_formal("<varargs>", elem_type);
  std::vector<Value*> elements;
  for (size_t pos = 0; pos < varargs.size(); ++pos) {
    Value* converted = tryMatchArgument(
        /*arg=*/element_formal,
        graph,
        loc,
        varargs[pos],
        failure_messages,
        err,
        /*allow_conversions=*/convert_tensor_to_num,
        type_env);
    if (converted == nullptr) {
      return nullptr;
    }
    elements.push_back(converted);
  }
  Node* list_node = graph.createList(elem_type, elements);
  return graph.insertNode(list_node)->output();
}
// Check if it is possible to convert all the remaining non-kwarg arguments
// to a list. This allows zeros(IntArrayRef sizes) to work with zeros(1, 2) or
// zeros(1)
static bool varargsCanBeUsedAsList(
const FunctionSchema& schema,
size_t arg_index,
const Argument& arg) {
// The arg must be the last one in the arg list that is not a kwarg
bool is_last_argument = arg_index + 1 == schema.arguments().size() ||
schema.arguments()[arg_index + 1].kwarg_only();
auto arg_type = arg.type();
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
arg_type = dyn->fallback();
}
// The formal must be a list
bool argument_is_list = arg_type->kind() == TypeKind::ListType;
// matching varargs of typevar list nyi
bool typevar_list = argument_is_list &&
arg_type->castRaw<ListType>()->getElementType()->cast<VarType>();
// it must not be a broadcasting list like int[3],
// otherwise a single int is a valid input
bool arg_is_broadcasting_list = bool(arg.N());
return is_last_argument && argument_is_list && !arg_is_broadcasting_list &&
!typevar_list;
}
/// Returns true for operator overloads that TorchScript schema matching must
/// never select. Entries are matched on (schema name, overload name).
bool isBlockListedSchema(const FunctionSchema& schema) {
  // aten::view.dtype (@zasdfgbnm): workaround for
  // https://github.com/pytorch/pytorch/issues/47964. JIT does not distinguish
  // ScalarType from int, so x.view(1) and x.view(torch.int8) cannot be told
  // apart; this overload stays blocked until JIT fully supports ScalarType as
  // its own type.
  //
  // aten::max.unary_out (@tugsbayasgalan): TorchScript doesn't support kwargs,
  // so this op collides with aten.max.others since both of them have 2 Tensor
  // inputs. Users are not expected to use it in TS, so it is skipped. The
  // aten::min.unary_out entry is blocked the same way.
  static const char* const kBlockList[][2] = {
      {"aten::view", "dtype"},
      {"aten::max", "unary_out"},
      {"aten::min", "unary_out"},
  };
  for (const auto& entry : kBlockList) {
    if (schema.name() == entry[0] && schema.overload_name() == entry[1]) {
      return true;
    }
  }
  return false;
}
// Attempts to match a call described by (`self`, `args`, `kwargs`) against
// `schema`.
//
// On success returns a MatchedSchema containing:
//  - one input Value* per schema argument, in schema order, followed by any
//    extra positional values when the schema is vararg;
//  - the return types with type variables resolved;
//  - the return field names, when every return is named;
//  - the full "name.overload" schema name.
// On failure returns nullopt; if `failure_messages` is non-null the reason is
// appended to it, prefixed by the schema. Block-listed schemas return nullopt
// without any message.
//
// Note: matching may insert nodes into `graph` (e.g. list construction and
// implicit conversions) even when it ultimately returns nullopt.
static c10::optional<MatchedSchema> tryMatchSchema(
    const FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    c10::optional<NamedValue> self,
    std::ostream* failure_messages,
    bool allow_conversions) {
  if (isBlockListedSchema(schema)) {
    return c10::nullopt;
  }
  // Writes a "\n<schema>:\n" header and returns the stream for the message
  // body. Dereferences `failure_messages`, so only call it when non-null.
  auto err = [&]() -> std::ostream& {
    *failure_messages << "\n" << schema << ":\n";
    return *failure_messages;
  };
  // For VarTypes, maps VarType name to actual type as it's used with these
  // args
  TypeEnv type_env;
  std::vector<Value*> positional_inputs;
  std::vector<bool> used_kwarg(kwargs.size(), false);
  // Used by findInputWithName to accept "input" as an alias for "self".
  auto schema_namespace = schema.operator_name().getNamespace();
  bool is_aten = false;
  if (schema_namespace.has_value()) {
    if (schema_namespace.value() == "aten") {
      is_aten = true;
    }
  }
  // if we finish the loop will we have consumed all arguments?
  size_t used_args = 0;
  for (const auto schema_i : c10::irange(schema.arguments().size())) {
    const auto& arg = schema.arguments()[schema_i];
    c10::optional<NamedValue> actual_named_value;
    if (arg.name() == "self" && self) {
      actual_named_value = self;
      // Mark `self` as consumed; an unconsumed self is rejected after the loop.
      self = c10::nullopt;
    } else if (!arg.kwarg_only() && used_args < args.size()) {
      // Try to convert all the remaining non-kwarg arguments (used_args) to a
      // list. Allow zeros(IntArrayRef sizes) to work with zeros(1, 2) or
      // zeros(1)
      if (allow_conversions && varargsCanBeUsedAsList(schema, schema_i, arg)) {
        auto value = args[used_args].value(graph);
        const auto& actual_type = value->type();
        // The actual cannot already be a list
        // (nor something convertibleToList, e.g. a matching tuple; those are
        // handled as a single argument below).
        if (actual_type->kind() != TypeKind::ListType &&
            !convertibleToList(actual_type, unwrapOptional(arg.type()))) {
          auto formal_type = unwrapOptional(arg.type())
                                 ->expectRef<ListType>()
                                 .getElementType();
          Value* list = tryCreateList(
              formal_type,
              graph,
              loc,
              at::ArrayRef<NamedValue>(args).slice(used_args),
              failure_messages,
              err,
              allow_conversions,
              type_env);
          if (!list) {
            return c10::nullopt;
          }
          // The list consumed every remaining positional argument.
          used_args = args.size();
          positional_inputs.push_back(list);
          continue;
        }
      }
      // Set actual_named_value to the argument and mark the arg position as
      // used
      actual_named_value = args[used_args];
      used_args++;
    } else if (
        auto kwarg_idx = findInputWithName(arg.name(), kwargs, is_aten)) {
      const NamedValue& nv = kwargs[*kwarg_idx];
      if (used_kwarg[*kwarg_idx]) {
        if (failure_messages) {
          err() << "Argument " << nv.name()
                << " specified twice in schema, submit a bug report!\n";
        }
        return c10::nullopt;
      }
      used_kwarg[*kwarg_idx] = true;
      actual_named_value = nv;
    } else if (arg.default_value()) {
      // Argument has a default value and no value was provided, so use the
      // default
      actual_named_value = NamedValue(*arg.default_value());
    } else {
      if (failure_messages) {
        err() << "Argument " << schema.arguments()[schema_i].name()
              << " not provided.\n";
      }
      return c10::nullopt;
    }
    // Make sure the actual_named_value found matches the type of arg
    Value* positional = tryMatchArgument(
        arg,
        graph,
        loc,
        *actual_named_value,
        failure_messages,
        err,
        allow_conversions,
        type_env);
    if (!positional) {
      return c10::nullopt;
    }
    positional_inputs.push_back(positional);
  }
  // check for unused self argument
  if (self != c10::nullopt) {
    if (failure_messages) {
      err() << "Provided self argument not used in schema.\n";
    }
    return c10::nullopt;
  }
  // Vararg schemas take any extra positional values as-is (no type matching).
  if (schema.is_vararg()) {
    for (; used_args < args.size(); ++used_args) {
      positional_inputs.push_back(args[used_args].value(graph));
    }
  }
  // check for unused positional arguments
  if (used_args < args.size()) {
    if (failure_messages) {
      err() << "Expected at most " << used_args << " arguments "
            << "but found " << args.size() << " positional arguments.\n";
    }
    return c10::nullopt;
  }
  // check for unused kwargs
  for (const auto i : c10::irange(kwargs.size())) {
    const auto& nv = kwargs[i];
    if (!used_kwarg[i]) {
      if (failure_messages) {
        if (!schema.argumentIndexWithName(nv.name())) {
          err() << "Keyword argument " << nv.name() << " unknown.\n";
        } else {
          // The schema has this argument but it was filled by another path,
          // presumably a positional or self argument.
          err() << "Keyword argument " << nv.name() << " specified twice.\n";
        }
      }
      return c10::nullopt;
    }
  }
  const auto& returns = schema.returns();
  // Every return type must be fully resolvable from the bindings collected in
  // type_env; otherwise this is an internal error.
  auto return_types = fmap(returns, [&](const Argument& r) {
    TypePtr result = tryEvalTypeVariables(r.type(), type_env);
    TORCH_INTERNAL_ASSERT(
        result, r.type()->repr_str(), " has unbound type variables.");
    return result;
  });
  // Codegen does not support return of namedtuples with undefined field names.
  // Therefore, either all or none returns has field names.
  // (With zero returns std::all_of is true, yielding an empty name list.)
  bool return_has_field_names =
      std::all_of(returns.begin(), returns.end(), [&](const Argument& r) {
        return r.name().length() > 0;
      });
  c10::OptNameList return_field_names = c10::nullopt;
  if (return_has_field_names) {
    return_field_names =
        fmap(returns, [&](const Argument& r) { return r.name(); });
  }
  // construct the full name of the schema for easier look up
  auto schema_name = getFullSchemaName(schema);
  return MatchedSchema{
      std::move(positional_inputs),
      std::move(return_types),
      std::move(return_field_names),
      schema_name};
}
// Matches the call against a single `schema` with implicit conversions
// enabled. Returns the match, or throws an ErrorReport at `loc` whose text
// is the collected mismatch diagnostics.
MatchedSchema matchSchema(
    const ::c10::FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const c10::optional<NamedValue>& self) {
  std::stringstream diagnostics;
  c10::optional<MatchedSchema> outcome = tryMatchSchema(
      schema,
      loc,
      graph,
      args,
      kwargs,
      self,
      &diagnostics,
      /*allow_conversions=*/true);
  if (!outcome) {
    throw ErrorReport(loc) << diagnostics.str();
  }
  return *outcome;
}
// Returns `str` with `prefix` inserted at the start of every line. A new line
// begins at the start of the string and after each '\n' that is followed by
// at least one more character; a trailing '\n' gets no prefix after it.
static std::string prefixLine(
    const std::string& str,
    const std::string& prefix) {
  std::string out;
  bool at_line_start = true;
  for (const char ch : str) {
    if (at_line_start) {
      out += prefix;
    }
    out.push_back(ch);
    at_line_start = (ch == '\n');
  }
  return out;
}
// Matches the call against the candidate `schemas`, returning the index of
// the first schema that matches together with the match result.
//
// All schemas are first tried without implicit conversions, then again with
// them, so that an exact overload wins over one requiring conversions. A
// single schema is matched directly with conversions enabled.
//
// Errors are not formatted on the first attempt (`render_errors == false`).
// On failure the whole match is replayed with rendering enabled to build the
// diagnostics, and an ErrorReport at `loc` listing every variant is thrown.
std::pair<size_t, MatchedSchema> matchSchemas(
    const std::vector<const FunctionSchema*>& schemas,
    const SourceRange& loc,
    Graph& graph,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const c10::optional<NamedValue>& self,
    bool render_errors) {
  TORCH_INTERNAL_ASSERT(!schemas.empty());
  // if there is only one schema, we do not need to try without conversions
  // first. this is faster and puts less dead code in the graph.
  if (schemas.size() == 1) {
    return std::make_pair(
        0, matchSchema(*schemas.at(0), loc, graph, args, kwargs, self));
  }
  std::stringstream failure_messages;
  for (bool allow_conversions : {false, true}) {
    // clear previous error messages
    failure_messages.str("");
    for (const auto i : c10::irange(schemas.size())) {
      auto matched_schema = tryMatchSchema(
          *schemas[i],
          loc,
          graph,
          args,
          kwargs,
          self,
          render_errors ? &failure_messages : nullptr,
          allow_conversions);
      if (matched_schema) {
        // Move the result out of the optional instead of copying it.
        return std::make_pair(i, std::move(*matched_schema));
      }
    }
  }
  // we optimistically assume this call will not error, and avoid formatting the
  // error strings. If we discover it did error, then we replay it, recording
  // the errors.
  if (!render_errors) {
    return matchSchemas(
        schemas, loc, graph, args, kwargs, self, /*render_errors=*/true);
  }
  throw ErrorReport(loc) << "Arguments for call are not valid.\n"
                         << "The following variants are available:\n"
                         << prefixLine(failure_messages.str(), "  ")
                         << "\nThe original call is";
}
// pack outputs of a function following python rules. If there is a single value
// return a SimpleValue, otherwise pack all the values into a Tuple.
//
// When `field_names` is provided, the tuple is given a named-tuple type built
// from those names and the values' types.
static Value* packOutputs(
    Graph& g,
    at::ArrayRef<Value*> values,
    c10::OptNameList field_names) {
  if (values.size() == 1) {
    return values[0];
  }
  TupleTypePtr named_tuple = nullptr;
  if (field_names) {
    auto types = fmap(values, [](Value* v) { return v->type(); });
    named_tuple =
        TupleType::createNamed(c10::nullopt, field_names.value(), types);
  }
  return g.insertNode(g.createTuple(values, named_tuple))->output();
}
// Emits a node named `name` for a successful schema match: the node takes the
// matched inputs, gets one output per resolved return type and carries `loc`
// as its source range. The outputs are then packed following Python rules.
//
// When no graph `version` is given, or the schema is current for that version,
// the node's operation is resolved immediately. That asserts an
// implementation exists, i.e. schema and dispatch are in sync. Otherwise the
// node is tagged with its historic schema name instead.
static Value* emitBuiltinNode(
    const MatchedSchema& matched_schema,
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    c10::optional<size_t> version) {
  Node* node = graph.create(name, matched_schema.inputs, 0);
  graph.insertNode(node);
  node->setSourceRange(loc);
  for (const auto& return_type : matched_schema.return_types) {
    node->addOutput()->setType(return_type);
  }
  const bool schema_is_current = !version.has_value() ||
      isOpSymbolCurrent(matched_schema.schema_name, version.value());
  if (schema_is_current) {
    node->getOperation();
  } else {
    node->setHistoricSchemaName(matched_schema.schema_name);
  }
  return packOutputs(
      graph, node->outputs(), matched_schema.return_field_names);
}
// Returns "<operator name>.<overload name>" for an overloaded schema, or just
// the operator name when the overload name is empty.
std::string getFullSchemaName(const ::c10::FunctionSchema& schema) {
  const std::string& overload = schema.overload_name();
  std::string full_name = schema.operator_name().name;
  if (!overload.empty()) {
    full_name += '.';
    full_name += overload;
  }
  return full_name;
}
// Search for operators matching the provided symbol name and input types.
// If one is found, emit a node to the graph for that operator.
//
// Candidate schemas are assembled in this order:
//  1. registered operators for `name`, each replaced by its old schema when
//     the graph version requires an upgrader;
//  2. when no operator is registered, historic schemas for `name`;
//  3. builtin (TorchScript-implemented) functions for `name`.
// Entries from (1) and (2) emit an operator node; a match in (3) inlines the
// builtin function's graph. Throws an ErrorReport if no candidate exists or
// none matches.
Value* emitBuiltinCall(
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const c10::optional<NamedValue>& self) {
  const auto& variants = getAllOperatorsFor(name);
  const auto& builtin_functions = getAllBuiltinFunctionsFor(name);

  // first let's set the graph's version
  auto graph_version = graph.get_op_version();

  std::vector<const FunctionSchema*> schemas;
  // we append them later to schemas because
  // parseSchema returns rvalue which can not
  // be casted to const pointer.
  std::vector<FunctionSchema> upgrader_schemas;
  schemas.reserve(variants.size());

  for (const std::shared_ptr<Operator>& op : variants) {
    bool found_upgrader = false;
    auto op_name = getFullSchemaName(op->schema());
    if (graph_version.has_value()) {
      auto version_entry = get_operator_version_map().find(op_name);
      if (version_entry != get_operator_version_map().end()) {
        auto old_schema_entry =
            findUpgrader(version_entry->second, graph_version.value());
        if (old_schema_entry.has_value()) {
          upgrader_schemas.push_back(
              parseSchema(old_schema_entry.value().old_schema));
          found_upgrader = true;
        } else {
          if (!isOpCurrentBasedOnUpgraderEntries(
                  version_entry->second, graph_version.value())) {
            TORCH_INTERNAL_ASSERT(false, "Valid upgrader must be present");
          }
        }
      }
    }
    if (!found_upgrader)
      schemas.push_back(&op->schema());
  }

  // we might have seen old historic
  // ops that are deprecated
  if (variants.empty()) {
    auto oldSchemas =
        loadPossibleHistoricOps(name.toQualString(), graph_version);
    upgrader_schemas.reserve(oldSchemas.size());
    for (const auto& old_schema_entry : oldSchemas) {
      upgrader_schemas.push_back(parseSchema(old_schema_entry));
    }
  }

  // TODO (tugsuu): make sure this is optimized later
  // upgrader_schemas is not modified after this point, so the pointers stay
  // valid.
  for (const auto& schema : upgrader_schemas) {
    schemas.push_back(&schema);
  }

  // Every schema before this index is an operator (current, upgraded or
  // historic); every schema from here on is a builtin function. Computing
  // this directly avoids miscounting when upgraders replaced some variants or
  // when only historic schemas exist.
  const size_t num_operator_schemas = schemas.size();

  for (const auto method : builtin_functions) {
    method->ensure_defined();
    schemas.push_back(&method->getSchema());
  }

  // no operators found with the same name, print out similarly named operators
  if (schemas.empty()) {
    const auto close_symbols = findSimilarOperators(name);
    auto error = ErrorReport(loc);
    const auto& user_function_name = name.toQualString();
    error << "Unknown builtin op: " << user_function_name << ".\n";
    if (close_symbols.empty()) {
      error
          << "Could not find any similar ops to " << user_function_name
          << ". This op may not exist or may not be currently supported in TorchScript.\n";
    } else {
      error << "Here are some suggestions: \n";
      for (const auto& sym : close_symbols) {
        error << "\t" << sym.toQualString() << "\n";
      }
      error << "\nThe original call is";
    }
    throw error;
  }

  auto matched = matchSchemas(schemas, loc, graph, args, kwargs, self);

  if (matched.first < num_operator_schemas) {
    return emitBuiltinNode(matched.second, loc, graph, name, graph_version);
  }
  auto& fn = *builtin_functions[matched.first - num_operator_schemas];
  // we inline builtin calls because they are normally very small
  // wrappers and are not useful for keeping around to debug
  return insertGraph(
             graph, *toGraphFunction(fn).graph(), matched.second.inputs)
      .at(0);
}
} // namespace torch::jit