# Names exported by a star-import of this module. Order is kept as in the
# original declaration; one name per line keeps future diffs minimal.
__all__ = [
    'TransformerException',
    'AstPrinter',
    'TemplateTransformer',
    'FixRowMultiplierPosArgTransformer',
    'RenameNodesTransformer',
    'TextEncodingDictTransformer',
    'FieldAnnotationTransformer',
    'SupportedAnnotationsTransformer',
    'RangeAnnotationTransformer',
    'CursorAnnotationTransformer',
    'AmbiguousSignatureCheckTransformer',
    'DefaultValueAnnotationTransformer',
    'DeclBracketTransformer',
    'Pipeline',
]
14 from ast
import literal_eval
15 from abc
import abstractmethod
17 if sys.version_info > (3, 0):
20 from abc
import ABCMeta
as ABC
23 import TableFunctionsFactory_util
as util
24 import TableFunctionsFactory_node
as tf_node
25 import TableFunctionsFactory_declbracket
as declbracket
41 raise NotImplementedError()
45 raise NotImplementedError()
49 raise NotImplementedError()
53 raise NotImplementedError()
57 raise NotImplementedError()
61 raise NotImplementedError()
65 """Only overload the methods you need"""
68 udtf = copy.copy(udtf_node)
69 udtf.inputs = [arg.accept(self)
for arg
in udtf.inputs]
70 udtf.outputs = [arg.accept(self)
for arg
in udtf.outputs]
72 udtf.templates = [t.accept(self)
for t
in udtf.templates]
73 udtf.annotations = [annot.accept(self)
for annot
in udtf.annotations]
77 c = copy.copy(composed_node)
78 c.inner = [i.accept(self)
for i
in c.inner]
82 arg_node = copy.copy(arg_node)
83 arg_node.type = arg_node.type.accept(self)
84 if arg_node.annotations:
85 arg_node.annotations = [a.accept(self)
for a
in arg_node.annotations]
89 return copy.copy(primitive_node)
92 return copy.copy(template_node)
95 return copy.copy(annotation_node)
99 """Returns a line formatted. Useful for testing"""
102 name = udtf_node.name
103 inputs =
", ".
join([arg.accept(self)
for arg
in udtf_node.inputs])
104 outputs =
", ".
join([arg.accept(self)
for arg
in udtf_node.outputs])
105 annotations =
"| ".
join([annot.accept(self)
for annot
in udtf_node.annotations])
106 sizer =
" | " + udtf_node.sizer.accept(self)
if udtf_node.sizer
else ""
108 annotations =
' | ' + annotations
109 if udtf_node.templates:
110 templates =
", ".
join([t.accept(self)
for t
in udtf_node.templates])
111 return "%s(%s)%s -> %s, %s%s" % (name, inputs, annotations, outputs, templates, sizer)
113 return "%s(%s)%s -> %s%s" % (name, inputs, annotations, outputs, sizer)
117 key = template_node.key
118 types = [
'"%s"' % typ
for typ
in template_node.types]
119 return "%s=[%s]" % (key,
", ".
join(types))
123 key = annotation_node.key
124 value = annotation_node.value
125 if isinstance(value, list):
126 return "%s=[%s]" % (key,
','.
join([v.accept(self)
for v
in value]))
127 return "%s=%s" % (key, value)
131 typ = arg_node.type.accept(self)
132 if arg_node.annotations:
133 ann =
" | ".
join([a.accept(self)
for a
in arg_node.annotations])
134 s =
"%s | %s" % (typ, ann)
140 T = composed_node.inner[0].accept(self)
141 if composed_node.is_array():
143 assert len(composed_node.inner) == 1
145 if composed_node.is_column():
147 assert len(composed_node.inner) == 1
149 if composed_node.is_column_list():
151 assert len(composed_node.inner) == 1
152 return "ColumnList" + T
153 if composed_node.is_output_buffer_sizer():
156 assert len(composed_node.inner) == 1
157 return util.translate_map.get(composed_node.type) +
"<%s>" % (N,)
158 if composed_node.is_cursor():
160 Ts =
", ".
join([i.accept(self)
for i
in composed_node.inner])
161 return "Cursor<%s>" % (Ts)
162 raise ValueError(composed_node)
165 t = primitive_node.type
166 if primitive_node.is_output_buffer_sizer():
168 return util.translate_map.get(t, t) +
"<%d>" % (
169 primitive_node.get_parent(tf_node.ArgNode).arg_pos + 1,
171 return util.translate_map.get(t, t)
175 """Like AstPrinter but returns a node instead of a string
181 vals = kwargs.values()
182 for instance
in itertools.product(*vals):
183 yield dict(zip(keys, instance))
187 """Expand template definition into multiple inputs"""
190 if not udtf_node.templates:
195 d = dict([(node.key, node.types)
for node
in udtf_node.templates])
196 name = udtf_node.name
200 inputs = [input_arg.accept(self)
for input_arg
in udtf_node.inputs]
201 outputs = [output_arg.accept(self)
for output_arg
in udtf_node.outputs]
202 udtf = tf_node.UdtfNode(name, inputs, outputs, udtf_node.annotations,
None, udtf_node.sizer, udtf_node.line)
203 udtfs[str(udtf)] = udtf
206 udtfs = list(udtfs.values())
214 typ = composed_node.type
215 typ = self.mapping_dict.get(typ, typ)
217 inner = [i.accept(self)
for i
in composed_node.inner]
218 return composed_node.copy(typ, inner)
221 typ = primitive_node.type
222 typ = self.mapping_dict.get(typ, typ)
223 return primitive_node.copy(typ)
229 * Fix kUserSpecifiedRowMultiplier without a pos arg
231 t = primitive_node.type
233 if primitive_node.is_output_buffer_sizer():
234 pos = tf_node.PrimitiveNode(str(primitive_node.get_parent(tf_node.ArgNode).arg_pos + 1))
235 node = tf_node.ComposedNode(t, inner=[pos])
238 return primitive_node
244 * Rename nodes using util.translate_map as dictionary
248 t = primitive_node.type
249 return primitive_node.copy(util.translate_map.get(t, t))
255 * Add default_input_id to Column(List)<TextEncodingDict> without one
259 default_input_id =
None
260 for idx, t
in enumerate(udtf_node.inputs):
262 if not isinstance(t.type, tf_node.ComposedNode):
264 if default_input_id
is not None:
266 elif t.type.is_column_text_encoding_dict()
or t.type.is_column_array_text_encoding_dict():
267 default_input_id = tf_node.AnnotationNode(
'input_id',
'args<%s>' % (idx,))
268 elif t.type.is_column_list_text_encoding_dict():
269 default_input_id = tf_node.AnnotationNode(
'input_id',
'args<%s, 0>' % (idx,))
271 for t
in udtf_node.outputs:
272 if isinstance(t.type, tf_node.ComposedNode)
and t.type.is_any_text_encoding_dict():
273 for a
in t.annotations:
274 if a.key ==
'input_id':
277 if default_input_id
is None:
278 raise TypeError(
'Cannot parse line "%s".\n'
279 'Missing TextEncodingDict input?' %
281 t.annotations.append(default_input_id)
290 * Generate fields annotation to Cursor if non-existing
294 for t
in udtf_node.inputs:
296 if not isinstance(t.type, tf_node.ComposedNode):
299 if t.type.is_cursor()
and t.get_annotation(
'fields')
is None:
300 fields = list(tf_node.PrimitiveNode(a.get_annotation(
'name',
'field%s' % i))
for i, a
in enumerate(t.type.inner))
301 t.annotations.append(tf_node.AnnotationNode(
'fields', fields))
309 * Typechecks default value annotations.
313 for t
in udtf_node.inputs:
314 for a
in filter(
lambda x: x.key ==
"default", t.annotations):
315 if not t.type.is_scalar():
317 'Error in function "%s", input annotation \'%s=%s\'. '
318 '\"default\" annotation is only supported for scalar types!'\
319 % (udtf_node.name, a.key, a.value)
321 literal = literal_eval(a.value)
322 lst = [(bool,
'is_boolean_scalar'), (int,
'is_integer_scalar'), (float,
'is_float_scalar'),
323 (str,
'is_string_scalar')]
325 for (cls, mthd)
in lst:
326 if type(literal)
is cls:
327 assert isinstance(t, tf_node.ArgNode)
328 m = getattr(t.type, mthd)
331 'Error in function "%s", input annotation \'%s=%s\'. '
332 'Argument is of type "%s" but value type was inferred as "%s".'
333 % (udtf_node.name, a.key, a.value, t.type.type,
type(literal).__name__))
341 * Checks for supported annotations in a UDTF
344 for t
in udtf_node.inputs:
345 for a
in t.annotations:
346 if a.key
not in util.SupportedAnnotations:
348 for t
in udtf_node.outputs:
349 for a
in t.annotations:
350 if a.key
not in util.SupportedAnnotations:
352 for annot
in udtf_node.annotations:
353 if annot.key
not in util.SupportedFunctionAnnotations:
355 if annot.value.lower()
in [
'enable',
'on',
'1',
'true']:
357 elif annot.value.lower()
in [
'disable',
'off',
'0',
'false']:
364 * A UDTF declaration is ambiguous if two or more ColumnLists are adjacent
366 func__0(ColumnList<T> X, ColumnList<T> Z) -> Column<U>
367 func__1(ColumnList<T> X, Column<T> Y, ColumnList<T> Z) -> Column<U>
368 The first ColumnList ends up consuming all of the arguments leaving a single
369 one for the last ColumnList. In other words, Z becomes a Column
374 lst: list[list[Node]]
377 for i
in range(len(l)):
378 if not l[i].is_column_list():
385 for j
in range(i+1, len(l)):
387 if l[j].is_column()
and l[j].is_column_of(T):
389 elif l[j].is_column_list()
and l[j].is_column_list_of(T):
390 msg = (
'%s signature is ambiguous as there are two '
391 'ColumnList with the same subtype in the same '
392 'group.') % (udtf_name)
393 if udtf_name
not in [
'ct_overload_column_list_test2__cpu_template']:
395 warnings.warn(msg, TransformerWarning)
402 for arg
in udtf_node.inputs:
404 if isinstance(s, list):
409 if cursor
or len(lst) == 0:
421 if composed_node.is_cursor():
422 return [i.accept(self)
for i
in composed_node.inner]
427 return arg_node.type.accept(self)
432 * Append require annotation if range is used
435 for ann
in arg_node.annotations:
436 if ann.key ==
'range':
437 name = arg_node.get_annotation(
'name')
444 value =
'"{lo} <= {name} && {name} <= {hi}"'.format(lo=lo, hi=hi, name=name)
447 arg_node.set_annotation(
'require', value)
453 * Move a "require" annotation from inside a cursor to the cursor
457 if arg_node.type.is_cursor():
458 for inner
in arg_node.type.inner:
459 for ann
in inner.annotations:
460 if ann.key ==
'require':
461 arg_node.annotations.append(ann)
468 name = udtf_node.name
470 input_annotations = []
472 output_annotations = []
473 function_annotations = []
474 sizer = udtf_node.sizer
476 for i
in udtf_node.inputs:
477 decl = i.accept(self)
479 input_annotations.append(decl.annotations)
481 for o
in udtf_node.outputs:
482 decl = o.accept(self)
483 outputs.append(decl.type)
484 output_annotations.append(decl.annotations)
486 for annot
in udtf_node.annotations:
487 annot = annot.accept(self)
488 function_annotations.append(annot)
490 return util.Signature(name, inputs, outputs, input_annotations, output_annotations, function_annotations, sizer)
493 t = arg_node.type.accept(self)
494 anns = [a.accept(self)
for a
in arg_node.annotations]
495 return declbracket.Declaration(t, anns)
498 typ = util.translate_map.get(composed_node.type, composed_node.type)
499 inner = [i.accept(self)
for i
in composed_node.inner]
500 if composed_node.is_cursor():
501 inner = list(map(
lambda x: x.apply_column(), inner))
502 return declbracket.Bracket(typ, args=tuple(inner))
503 elif composed_node.is_output_buffer_sizer():
504 return declbracket.Bracket(typ, args=tuple(inner))
506 return declbracket.Bracket(typ + str(inner[0]))
509 t = primitive_node.type
510 return declbracket.Bracket(t)
513 key = annotation_node.key
514 value = annotation_node.value
523 if not isinstance(ast_list, list):
524 ast_list = [ast_list]
527 ast_list = [ast.accept(c())
for ast
in ast_list]
528 ast_list = itertools.chain.from_iterable(
529 map(
lambda x: x
if isinstance(x, list)
else [x], ast_list))
531 return list(ast_list)
size_t append(FILE *f, const size_t size, const int8_t *buf)
Appends size bytes from buf to the end of the file f.