1: <?php
2:
3: declare(strict_types=1);
4:
5: namespace Atk4\Data\Persistence;
6:
7: use Atk4\Data\Exception;
8: use Atk4\Data\Field;
9: use Atk4\Data\Field\SqlExpressionField;
10: use Atk4\Data\Model;
11: use Atk4\Data\Persistence;
12: use Atk4\Data\Persistence\Sql\Connection;
13: use Atk4\Data\Persistence\Sql\Exception as SqlException;
14: use Atk4\Data\Persistence\Sql\Expression;
15: use Atk4\Data\Persistence\Sql\Expressionable;
16: use Atk4\Data\Persistence\Sql\Query;
17: use Atk4\Data\Reference\HasOneSql;
18: use Doctrine\DBAL\Connection as DbalConnection;
19: use Doctrine\DBAL\Driver\Connection as DbalDriverConnection;
20: use Doctrine\DBAL\Platforms\AbstractPlatform;
21: use Doctrine\DBAL\Platforms\OraclePlatform;
22: use Doctrine\DBAL\Platforms\PostgreSQLPlatform;
23: use Doctrine\DBAL\Platforms\SQLServerPlatform;
24:
25: class Sql extends Persistence
26: {
// helpers for encoding/decoding binary column values, used by typecastSaveField() / typecastLoadField() below
use Sql\BinaryTypeCompatibilityTypecastTrait;

// Hook names fired on the model by this persistence, with the arguments this class passes:
// - HOOK_INIT_SELECT_QUERY: [Query $query, string $type], from action() for 'select', 'count' and 'exists'
// - HOOK_BEFORE_INSERT_QUERY / HOOK_AFTER_INSERT_QUERY: [Query $insert], around the insert statement
// - HOOK_BEFORE_UPDATE_QUERY / HOOK_AFTER_UPDATE_QUERY: [Query $update], around the update statement
// - HOOK_BEFORE_DELETE_QUERY / HOOK_AFTER_DELETE_QUERY: [Query $delete], around the delete statement
public const HOOK_INIT_SELECT_QUERY = self::class . '@initSelectQuery';
public const HOOK_BEFORE_INSERT_QUERY = self::class . '@beforeInsertQuery';
public const HOOK_AFTER_INSERT_QUERY = self::class . '@afterInsertQuery';
public const HOOK_BEFORE_UPDATE_QUERY = self::class . '@beforeUpdateQuery';
public const HOOK_AFTER_UPDATE_QUERY = self::class . '@afterUpdateQuery';
public const HOOK_BEFORE_DELETE_QUERY = self::class . '@beforeDeleteQuery';
public const HOOK_AFTER_DELETE_QUERY = self::class . '@afterDeleteQuery';

/** @var Connection Active connection; set to null by disconnect(). */
private $_connection;

/** @var array<mixed> Default seed when adding a new field; left null (no custom seed needed). */
protected $_defaultSeedAddField; // no custom seed needed

/** @var array<mixed> Default seed when adding an Expression field. */
protected $_defaultSeedAddExpression = [SqlExpressionField::class];

/** @var array<mixed> Default seed when adding a hasOne reference. */
protected $_defaultSeedHasOne = [HasOneSql::class];

/** @var array<mixed> Default seed when adding a hasMany reference; left null (no custom seed needed). */
protected $_defaultSeedHasMany; // no custom seed needed

/** @var array<mixed> Default seed when adding a join. */
protected $_defaultSeedJoin = [Sql\Join::class];
54:
/**
 * Creates the SQL persistence, reusing a ready Connection or opening a new one.
 *
 * @param Connection|string|array<string, string>|DbalConnection|DbalDriverConnection $connection
 *        an existing Connection (used as-is), or anything accepted by Connection::connect()
 * @param string|null          $user     forwarded to Connection::connect()
 * @param string|null          $password forwarded to Connection::connect()
 * @param array<string, mixed> $defaults forwarded to Connection::connect()
 */
public function __construct($connection, $user = null, $password = null, $defaults = [])
{
    // anything other than a ready-made Connection goes through the connection factory
    $this->_connection = $connection instanceof Connection
        ? $connection
        : Connection::connect($connection, $user, $password, $defaults);
}
77:
/**
 * Returns the connection this persistence runs its queries on.
 *
 * Fails with a TypeError after disconnect(), which nulls the stored connection.
 */
public function getConnection(): Connection
{
return $this->_connection;
}
82:
/**
 * Disconnects via the parent implementation and drops the connection reference.
 *
 * Afterwards getConnection() (and everything built on it) fails with a TypeError,
 * because the stored connection is null while its declared return type is Connection.
 */
#[\Override]
public function disconnect(): void
{
parent::disconnect();

// intentionally violates the non-null @var of $_connection, hence the ignore
$this->_connection = null; // @phpstan-ignore-line
}
90:
/**
 * Runs $fx through Connection::atomic() (presumably inside a transaction) and returns its result.
 */
#[\Override]
public function atomic(\Closure $fx)
{
    $connection = $this->getConnection();

    return $connection->atomic($fx);
}
96:
/**
 * Returns the DBAL platform of the underlying connection.
 */
#[\Override]
public function getDatabasePlatform(): AbstractPlatform
{
    $connection = $this->getConnection();

    return $connection->getDatabasePlatform();
}
102:
/**
 * Attaches $model to this persistence, injecting the SQL-specific default seeds.
 *
 * Seeds given explicitly in $defaults take precedence over this persistence's seeds.
 *
 * @throws Exception when the model has no $table set (null)
 */
#[\Override]
public function add(Model $model, array $defaults = []): void
{
    $sqlSeeds = [
        '_defaultSeedAddField' => $this->_defaultSeedAddField,
        '_defaultSeedAddExpression' => $this->_defaultSeedAddExpression,
        '_defaultSeedHasOne' => $this->_defaultSeedHasOne,
        '_defaultSeedHasMany' => $this->_defaultSeedHasMany,
        '_defaultSeedJoin' => $this->_defaultSeedJoin,
    ];

    parent::add($model, array_merge($sqlSeeds, $defaults));

    if ($model->table === null) {
        throw (new Exception('Property $table must be specified for a model'))
            ->addMoreInfo('model', $model);
    }

    if ($model->table !== false) {
        return;
    }

    // a table-less model cannot have real IDs: replace the ID field with a constant -1 expression
    $model->removeField($model->idField);
    $model->addExpression($model->idField, ['expr' => '-1', 'type' => 'integer']);
}
127:
/**
 * Registers the persistence-backed helper methods expr(), dsql() and exprNow() on the model.
 *
 * Each closure receives the model as its first argument ($m) followed by the caller's
 * arguments, and forwards to the corresponding method of this persistence.
 */
#[\Override]
protected function initPersistence(Model $model): void
{
    $model->addMethod('expr', static function (Model $m, ...$args) {
        // Sql::expr() takes the model as its first parameter
        return $m->getPersistence()->expr($m, ...$args);
    });
    $model->addMethod('dsql', static function (Model $m, ...$args) {
        return $m->getPersistence()->dsql($m, ...$args); // @phpstan-ignore-line
    });
    $model->addMethod('exprNow', static function (Model $m, ...$args) {
        // Sql::exprNow() does NOT take the model; forwarding $m would bind it to the
        // ?int $precision parameter and always fail with a TypeError
        return $m->getPersistence()->exprNow(...$args);
    });
}
141:
/**
 * Creates a new Expression from a template string, auto-binding model fields.
 *
 * Every `[name]` or `{name}` placeholder in $template that has no entry in $arguments
 * is bound to $model->getField('name'). Empty placeholders (`[]`, `{}`) are left for
 * positional arguments. The template itself is passed unchanged to Connection::expr().
 *
 * @param array<int|string, mixed> $arguments
 */
public function expr(Model $model, string $template, array $arguments = []): Expression
{
// preg_replace_callback() is used only to walk the placeholders: its return value is
// discarded and the callback fills $arguments by reference
preg_replace_callback(
// 1st alternative: a quoted token (per Expression::QUOTED_TOKEN_REGEX - presumably quoted
// strings/identifiers, TODO confirm) not starting with a `[name]` placeholder; `\K` makes
// its reported match empty so the callback skips it.
// 2nd/3rd alternatives: `[name]` and `{name}` placeholders.
'~(?!\[\w*\])' . Expression::QUOTED_TOKEN_REGEX . '\K|\[\w*\]|\{\w*\}~',
static function ($matches) use ($model, &$arguments) {
if ($matches[0] === '') {
// quoted token matched by the 1st alternative - nothing to bind
return '';
}

// strip the surrounding brackets/braces
$identifier = substr($matches[0], 1, -1);
// NOTE(review): isset() also treats an explicitly passed null argument as missing and
// replaces it with the field; use array_key_exists() if null must be bindable
if ($identifier !== '' && !isset($arguments[$identifier])) {
$arguments[$identifier] = $model->getField($identifier);
}

return $matches[0];
},
$template
);

return $this->getConnection()->expr($template, $arguments);
}
168:
/**
 * Creates an Expression for the current time.
 *
 * @param int|null $precision forwarded as-is to Query::exprNow()
 */
public function exprNow(?int $precision = null): Expression
{
    // explicit ?int: implicitly nullable "int $x = null" is deprecated as of PHP 8.4
    return $this->getConnection()->dsql()->exprNow($precision);
}
176:
/**
 * Creates a new, empty Query on the underlying connection.
 */
public function dsql(): Query
{
    $connection = $this->getConnection();

    return $connection->dsql();
}
184:
/**
 * Creates the base query for $model: its table (if any) plus its CTE definitions.
 *
 * A model-valued $table is embedded as a sub-select aliased by $tableAlias, or "_tm" when
 * no alias is set; a string $table uses $tableAlias as-is (possibly null).
 */
public function initQuery(Model $model): Query
{
    $query = $this->dsql();

    $table = $model->table;
    if ($table) {
        if (is_object($table)) {
            $query->table($table->action('select'), $model->tableAlias ?? '_tm');
        } else {
            $query->table($table, $model->tableAlias);
        }
    }

    $this->initWithCursors($model, $query);

    return $query;
}
203:
/**
 * Adds a WITH clause to $query for every CTE model registered on $model.
 */
public function initWithCursors(Model $model, Query $query): void
{
    foreach ($model->cteModels as $alias => ['model' => $cteModel, 'recursive' => $isRecursive]) {
        $cteQuery = $cteModel->action('select');
        $query->with($cteQuery, $alias, null, $isRecursive);
    }
}
211:
/**
 * Adds $field to the SELECT list of $query, aliased by its short name when the field asks for an alias.
 */
public function initField(Query $query, Field $field): void
{
    $alias = null;
    if ($field->useAlias()) {
        $alias = $field->shortName;
    }

    $query->field($field, $alias);
}
219:
/**
 * Adds model fields to the SELECT list of $query.
 *
 * The selection mode is picked in this order:
 *  - $fields given: exactly those fields, in the given order; system and neverPersist
 *    flags are not consulted;
 *  - $model->onlyFields set: those fields first, then any system field not yet added;
 *  - otherwise: every field of the model.
 * In the last two modes fields flagged neverPersist are skipped.
 *
 * @param array<int, string>|null $fields
 */
public function initQueryFields(Model $model, Query $query, ?array $fields = null): void
{
    if ($fields !== null) {
        // set of fields is strictly defined, so we will ignore even system fields
        foreach ($fields as $fieldName) {
            $this->initField($query, $model->getField($fieldName));
        }

        return;
    }

    if ($model->onlyFields !== null) {
        $addedFields = [];

        // add requested fields first
        foreach ($model->onlyFields as $fieldName) {
            $field = $model->getField($fieldName);
            if ($field->neverPersist) {
                continue;
            }
            $this->initField($query, $field);
            $addedFields[$fieldName] = true;
        }

        // now add system fields, if they were not added
        foreach ($model->getFields() as $fieldName => $field) {
            if (!$field->neverPersist && $field->system && !isset($addedFields[$fieldName])) {
                $this->initField($query, $field);
            }
        }

        return;
    }

    foreach ($model->getFields() as $field) {
        if (!$field->neverPersist) {
            $this->initField($query, $field);
        }
    }
}
264:
/**
 * Copies the limit and ordering of $model onto $query.
 *
 * A limit is applied only when a count is set or the offset is non-zero; a missing count
 * with an offset is expressed as \PHP_INT_MAX rows. Order entries may reference a field
 * name or any Expressionable.
 */
protected function setLimitOrder(Model $model, Query $query): void
{
    [$limitCount, $limitShift] = $model->limit;
    if ($limitCount !== null || $limitShift !== 0) {
        $query->limit($limitCount ?? \PHP_INT_MAX, $limitShift);
    }

    foreach ($model->order as [$orderBy, $direction]) {
        $descending = strtolower($direction) === 'desc';
        $orderExpr = $orderBy instanceof Expressionable
            ? $orderBy
            : $model->getField($orderBy);

        $query->order($orderExpr, $descending);
    }
}
286:
/**
 * Will apply model scope/conditions onto $query.
 *
 * Conditions come from the scope of $model->getModel(true), so an entity contributes the
 * same conditions as its model. If $model is an entity with a non-null ID, the query is
 * additionally restricted to that ID.
 */
public function initQueryConditions(Model $model, Query $query): void
{
$this->_initQueryConditions($query, $model->getModel(true)->scope());

// add entity ID to scope to allow easy traversal
// (done as GROUP BY id + HAVING id = <id>; on MSSQL/Oracle fixMssqlOracleMissingFieldsInGroup()
// also adds the selected Field columns to the GROUP BY)
// NOTE(review): the ID is used as returned by getId(), without typecastSaveField(), unlike
// the 'field' action in action() - confirm this is intended
if ($model->isEntity() && $model->idField && $model->getId() !== null) {
$query->group($model->getField($model->idField));
$this->fixMssqlOracleMissingFieldsInGroup($model, $query);
$query->having($model->getField($model->idField), $model->getId());
}
}
301:
/**
 * On SQL Server and Oracle, adds every selected Field to GROUP BY once the ID field is grouped.
 *
 * Does nothing on other platforms, or when the ID field is not part of the GROUP BY.
 */
private function fixMssqlOracleMissingFieldsInGroup(Model $model, Query $query): void
{
    $platform = $this->getDatabasePlatform();
    $isAffectedPlatform = $platform instanceof SQLServerPlatform || $platform instanceof OraclePlatform;
    if (!$isAffectedPlatform) {
        return;
    }

    $idFieldGrouped = false;
    foreach ($query->args['group'] ?? [] as $groupedItem) {
        if ($model->idField && $groupedItem === $model->getField($model->idField)) {
            $idFieldGrouped = true;

            break;
        }
    }

    if (!$idFieldGrouped) {
        return;
    }

    // add every plain Field column of the SELECT list to GROUP BY as well
    foreach ($query->args['field'] ?? [] as $selected) {
        if ($selected instanceof Field) {
            $query->group($selected);
        }
    }
}
324:
/**
 * Recursively translates a scope (condition tree) into WHERE clauses on $query.
 *
 * A single condition becomes one where() call. A nested scope becomes an OR/AND
 * sub-expression whose members are added recursively. A null or empty scope adds nothing.
 */
private function _initQueryConditions(Query $query, ?Model\Scope\AbstractScope $condition = null): void
{
    // the parameter defaults to null: treat that like an empty scope instead of calling a method on null
    if ($condition === null || $condition->isEmpty()) {
        return;
    }

    // peel off the single nested scopes to convert (((field = value))) to field = value
    $condition = $condition->simplify();

    // simple condition
    if ($condition instanceof Model\Scope\Condition) {
        $query->where(...$condition->toQueryArguments());
    }

    // nested conditions
    if ($condition instanceof Model\Scope) {
        $expression = $condition->isOr() ? $query->orExpr() : $query->andExpr();

        foreach ($condition->getNestedConditions() as $nestedCondition) {
            $this->_initQueryConditions($expression, $nestedCondition);
        }

        $query->where($expression);
    }
}
348:
/**
 * Builds the query for a model action.
 *
 * Supported $type values and their $args:
 *  - 'select': SELECT with the model's fields, conditions, limit and order;
 *    $args[0] optionally restricts the field list (see initQueryFields()).
 *  - 'count': SELECT count(*) honouring the model's conditions;
 *    $args['alias'] optionally names the result column.
 *  - 'exists': the conditioned query wrapped by Query::exists().
 *  - 'field': SELECT of a single field. $args[0] is a field name or Field, and $args['alias']
 *    optionally names the result column. For a loaded entity the query is also restricted
 *    to its ID.
 *  - 'fx' / 'fx0': aggregate $args[0] (e.g. "sum") applied to field $args[1]. 'fx0' wraps the
 *    aggregate in coalesce(..., 0). The "concat" aggregate uses Query::groupConcat() with the
 *    optional $args['concatSeparator']. $args['alias'] optionally names the result column.
 *
 * HOOK_INIT_SELECT_QUERY is fired with [$query, $type] for 'select', 'count' and 'exists'.
 * 'field' and 'fx' build on the 'select' action, so they fire it with type 'select'.
 *
 * @param array<mixed> $args
 *
 * @return Query
 *
 * @throws Exception when required arguments are missing or $type is unsupported
 */
public function action(Model $model, string $type, array $args = [])
{
    switch ($type) {
        case 'select':
            $query = $this->initQuery($model);
            $this->initQueryFields($model, $query, $args[0] ?? null);
            $this->initQueryConditions($model, $query);
            $this->setLimitOrder($model, $query);
            $model->hook(self::HOOK_INIT_SELECT_QUERY, [$query, $type]);

            return $query;
        case 'count':
            $query = $this->initQuery($model);
            $this->initQueryConditions($model, $query);
            $model->hook(self::HOOK_INIT_SELECT_QUERY, [$query, $type]);

            return $query->reset('field')->field('count(*)', $args['alias'] ?? null);
        case 'exists':
            $query = $this->initQuery($model);
            $this->initQueryConditions($model, $query);
            $model->hook(self::HOOK_INIT_SELECT_QUERY, [$query, $type]);

            return $query->exists();
        case 'field':
            if (!isset($args[0])) {
                throw (new Exception('This action requires one argument with field name'))
                    ->addMoreInfo('action', $type);
            }
            $field = $args[0];
            if (is_string($field)) {
                $field = $model->getField($field);
            }

            // start from the full select (conditions, limit, order) and replace its field list
            $query = $this->action($model, 'select', [[]]);

            if (isset($args['alias'])) {
                $query->reset('field')->field($field, $args['alias']);
            } elseif ($field instanceof SqlExpressionField) {
                $query->reset('field')->field($field, $field->shortName);
            } else {
                $query->reset('field')->field($field);
            }
            $this->fixMssqlOracleMissingFieldsInGroup($model, $query);

            if ($model->isEntity() && $model->isLoaded()) {
                $idRaw = $this->typecastSaveField($model->getField($model->idField), $model->getId());
                $query->where($model->getField($model->idField), $idRaw);
            }

            return $query;
        case 'fx':
        case 'fx0':
            if (!isset($args[0], $args[1])) {
                throw (new Exception('fx action needs 2 arguments, eg: ["sum", "amount"]'))
                    ->addMoreInfo('action', $type);
            }
            [$fx, $field] = $args;
            if (is_string($field)) {
                $field = $model->getField($field);
            }

            $query = $this->action($model, 'select', [[]]);

            if ($fx === 'concat') {
                // the separator is optional: when it is not given, rely on groupConcat()'s own
                // default instead of reading an undefined key and passing null
                $expr = isset($args['concatSeparator'])
                    ? $query->groupConcat($field, $args['concatSeparator'])
                    : $query->groupConcat($field);
            } else {
                $expr = $query->expr(
                    $type === 'fx'
                        ? $fx . '([])'
                        : 'coalesce(' . $fx . '([]), 0)',
                    [$field]
                );
            }

            if (isset($args['alias'])) {
                $query->reset('field')->field($expr, $args['alias']);
            } elseif ($field instanceof SqlExpressionField) {
                $query->reset('field')->field($expr, $fx . '_' . $field->shortName);
            } else {
                $query->reset('field')->field($expr);
            }
            $this->fixMssqlOracleMissingFieldsInGroup($model, $query);

            return $query;
        default:
            throw (new Exception('Unsupported action mode'))
                ->addMoreInfo('type', $type);
    }
}
443:
/**
 * Loads a single row for $model and returns it typecast to PHP values, or null when nothing matches.
 *
 * $id is either a concrete ID value or one of the markers self::ID_LOAD_ONE / self::ID_LOAD_ANY,
 * which skip the ID condition. With ID_LOAD_ANY the first row is taken. Otherwise more than one
 * matching row is an error.
 *
 * @throws Exception when loading by ID without an idField, when more than one row matches,
 *                   when the query fails, or when the loaded row lacks the ID value
 */
#[\Override]
public function tryLoad(Model $model, $id): ?array
{
$model->assertIsModel();

// the special load markers mean "no ID condition"
$noId = $id === self::ID_LOAD_ONE || $id === self::ID_LOAD_ANY;

$query = $model->action('select');

if (!$noId) {
if (!$model->idField) {
throw (new Exception('Unable to load by "id" when Model->idField is not defined'))
->addMoreInfo('id', $id);
}

$idRaw = $this->typecastSaveField($model->getField($model->idField), $id);
$query->where($model->getField($model->idField), $idRaw);
}
// fetch at most 2 rows (1 for ID_LOAD_ANY) so ambiguity can be detected cheaply, never
// exceeding a row limit already set on the model and keeping its offset
$query->limit(
min($id === self::ID_LOAD_ANY ? 1 : 2, $query->args['limit']['cnt'] ?? \PHP_INT_MAX),
$query->args['limit']['shift'] ?? null
);

// execute action
try {
$rowsRaw = $query->getRows();
if (count($rowsRaw) === 0) {
return null;
} elseif (count($rowsRaw) !== 1) {
throw (new Exception('Ambiguous conditions, more than one record can be loaded'))
->addMoreInfo('model', $model)
->addMoreInfo('idField', $model->idField)
->addMoreInfo('id', $noId ? null : $id);
}
$data = $this->typecastLoadRow($model, $rowsRaw[0]);
} catch (SqlException $e) {
// only SQL-level failures are wrapped; the ambiguity Exception above propagates unchanged
throw (new Exception('Unable to load due to query error', 0, $e))
->addMoreInfo('model', $model)
->addMoreInfo('scope', $model->scope()->toWords());
}

if ($model->idField && !isset($data[$model->idField])) {
// TODO detect even an ID change here!
throw (new Exception('Model uses "idField" but it was not available in the database'))
->addMoreInfo('model', $model)
->addMoreInfo('idField', $model->idField)
->addMoreInfo('id', $noId ? null : $id)
->addMoreInfo('data', $data);
}

return $data;
}
496:
/**
 * Exports all rows of the model's data set.
 *
 * @param array<int, string>|null $fields   fields to select (see initQueryFields()); null for the default selection
 * @param bool                    $typecast when true, rows are typecast to PHP values via typecastLoadRow()
 *
 * @return array<int, array<string, mixed>>
 */
public function export(Model $model, ?array $fields = null, bool $typecast = true): array
{
    // explicit ?array: implicitly nullable "array $x = null" is deprecated as of PHP 8.4
    $rows = $model->action('select', [$fields])->getRows();

    if (!$typecast) {
        return $rows;
    }

    return array_map(function (array $row) use ($model) {
        return $this->typecastLoadRow($model, $row);
    }, $rows);
}
516:
/**
 * Returns an iterator over the raw (not typecast) rows selected by the model.
 *
 * SQL errors raised while obtaining the iterator are wrapped in a data Exception.
 *
 * @return \Traversable<array<string, mixed>>
 */
public function prepareIterator(Model $model): \Traversable
{
    $selectQuery = $model->action('select');

    try {
        $rowsIterator = $selectQuery->getRowsIterator();
    } catch (SqlException $e) {
        throw (new Exception('Unable to execute iteration query', 0, $e))
            ->addMoreInfo('model', $model)
            ->addMoreInfo('scope', $model->scope()->toWords());
    }

    return $rowsIterator;
}
532:
/**
 * Throws unless exactly one row was affected by a write statement.
 *
 * @param mixed  $idRaw        persisted ID of the record, or null when not known (insert)
 * @param string $operation    lower-case operation name used in the message, e.g. "update"
 */
private function assertExactlyOneRecordUpdated(Model $model, $idRaw, int $affectedRows, string $operation): void
{
    if ($affectedRows === 1) {
        return;
    }

    $message = ucfirst($operation) . ' failed, exactly 1 row was expected to be affected';

    throw (new Exception($message))
        ->addMoreInfo('model', $model)
        ->addMoreInfo('scope', $model->scope()->toWords())
        ->addMoreInfo('idRaw', $idRaw)
        ->addMoreInfo('affectedRows', $affectedRows);
}
546:
/**
 * Inserts one row and returns its raw ID.
 *
 * The ID is taken from $dataRaw when provided, otherwise from the connection's last insert ID.
 * Without an idField an empty string is returned.
 *
 * @param array<scalar|Expressionable|null> $dataRaw
 */
#[\Override]
protected function insertRaw(Model $model, array $dataRaw)
{
    $insertQuery = $this->initQuery($model);
    $insertQuery->mode('insert');
    $insertQuery->setMulti($dataRaw);

    $model->hook(self::HOOK_BEFORE_INSERT_QUERY, [$insertQuery]);

    try {
        $affectedRows = $insertQuery->executeStatement();
    } catch (SqlException $e) {
        throw (new Exception('Unable to execute insert query', 0, $e))
            ->addMoreInfo('model', $model)
            ->addMoreInfo('scope', $model->scope()->toWords());
    }

    $this->assertExactlyOneRecordUpdated($model, null, $affectedRows, 'insert');

    $idRaw = '';
    if ($model->idField) {
        // explicit ID from the inserted data wins; lastInsertId() is only queried when it is missing
        $idRaw = $dataRaw[$model->getField($model->idField)->getPersistenceName()]
            ?? $this->lastInsertId($model);
    }

    $model->hook(self::HOOK_AFTER_INSERT_QUERY, [$insertQuery]);

    return $idRaw;
}
583:
/**
 * Updates the row identified by $idRaw with the (already typecast) modified values.
 *
 * @param array<scalar|Expressionable|null> $dataRaw
 */
#[\Override]
protected function updateRaw(Model $model, $idRaw, array $dataRaw): void
{
    $updateQuery = $this->initQuery($model);
    $updateQuery->mode('update');

    // only apply fields that has been modified
    $updateQuery->setMulti($dataRaw);
    $idColumn = $model->getField($model->idField)->getPersistenceName();
    $updateQuery->where($idColumn, $idRaw);

    $model->hook(self::HOOK_BEFORE_UPDATE_QUERY, [$updateQuery]);

    try {
        $affectedRows = $updateQuery->executeStatement();
    } catch (SqlException $e) {
        throw (new Exception('Unable to update due to query error', 0, $e))
            ->addMoreInfo('model', $model)
            ->addMoreInfo('scope', $model->scope()->toWords());
    }

    $this->assertExactlyOneRecordUpdated($model, $idRaw, $affectedRows, 'update');

    $model->hook(self::HOOK_AFTER_UPDATE_QUERY, [$updateQuery]);
}
611:
/**
 * Deletes the row identified by $idRaw; exactly one row must be affected.
 */
#[\Override]
protected function deleteRaw(Model $model, $idRaw): void
{
    $deleteQuery = $this->initQuery($model);
    $deleteQuery->mode('delete');
    $idColumn = $model->getField($model->idField)->getPersistenceName();
    $deleteQuery->where($idColumn, $idRaw);

    $model->hook(self::HOOK_BEFORE_DELETE_QUERY, [$deleteQuery]);

    try {
        $affectedRows = $deleteQuery->executeStatement();
    } catch (SqlException $e) {
        throw (new Exception('Unable to delete due to query error', 0, $e))
            ->addMoreInfo('model', $model)
            ->addMoreInfo('scope', $model->scope()->toWords());
    }

    $this->assertExactlyOneRecordUpdated($model, $idRaw, $affectedRows, 'delete');

    $model->hook(self::HOOK_AFTER_DELETE_QUERY, [$deleteQuery]);
}
632:
/**
 * Typecasts a value for saving, then binary-encodes it when the field type requires it.
 */
#[\Override]
public function typecastSaveField(Field $field, $value)
{
    $saved = parent::typecastSaveField($field, $value);

    return $saved !== null && $this->binaryTypeIsEncodeNeeded($field->type)
        ? $this->binaryTypeValueEncode($saved)
        : $saved;
}
644:
/**
 * Typecasts a loaded value, then binary-decodes it when the field type and value require it.
 */
#[\Override]
public function typecastLoadField(Field $field, $value)
{
    $loaded = parent::typecastLoadField($field, $value);

    if ($loaded === null || !$this->binaryTypeIsDecodeNeeded($field->type, $loaded)) {
        return $loaded;
    }

    return $this->binaryTypeValueDecode($loaded);
}
656:
/**
 * Parent save-typecast, with Oracle's empty-string semantics applied.
 *
 * Oracle always converts empty string to null, so on Oracle a non-binary '' is saved as null.
 * See https://stackoverflow.com/questions/13278773/null-vs-empty-string-in-oracle#13278879
 */
#[\Override]
protected function _typecastSaveField(Field $field, $value)
{
    $res = parent::_typecastSaveField($field, $value);

    if ($res !== '' || !($this->getDatabasePlatform() instanceof OraclePlatform)) {
        return $res;
    }

    // binary values are encoded later by typecastSaveField(), so keep '' for them
    return $this->binaryTypeIsEncodeNeeded($field->type) ? $res : null;
}
670:
/**
 * Builds the SQL expression referencing $field's column.
 *
 * When the owning model sets persistenceData['use_table_prefixes'], the column is qualified with
 * the join alias (for joined fields), or with the model's table alias / "_tm" / table name.
 * Otherwise the bare column name is used. The owner's expr() method is preferred when available,
 * falling back to $expression->expr().
 */
public function getFieldSqlExpression(Field $field, Expression $expression): Expression
{
    $owner = $field->getOwner();

    if (!isset($owner->persistenceData['use_table_prefixes'])) {
        // references set flag use_table_prefixes, so no need to check them here
        $mask = '{}';
        $prop = [$field->getPersistenceName()];
    } else {
        if ($field->hasJoin()) {
            $join = $field->getJoin();
            $tableRef = $join->foreignAlias ?? $join->shortName;
        } else {
            $tableRef = $owner->tableAlias ?? (is_object($owner->table) ? '_tm' : $owner->table);
        }

        $mask = '{{}}.{}';
        $prop = [$tableRef, $field->getPersistenceName()];
    }

    // prefer the model's expr() method (registered by Persistence\Sql) over the raw expression
    return $owner->hasMethod('expr')
        ? $owner->expr($mask, $prop)
        : $expression->expr($mask, $prop);
}
697:
/**
 * Returns the ID generated by the last insert into $model's table.
 *
 * PostgreSQL and Oracle DBAL platforms use a sequence internally for PK autoincrement, so the
 * default sequence name is derived for them. Other platforms query without a sequence name.
 *
 * @throws \Error when the model's table is not a plain table name
 */
public function lastInsertId(Model $model): string
{
    if (is_object($model->table)) {
        throw new \Error('Table must be a string');
    }

    $connection = $this->getConnection();
    $platform = $connection->getDatabasePlatform();

    if ($platform instanceof PostgreSQLPlatform) {
        $sequenceName = $platform->getIdentitySequenceName(
            $model->table,
            $model->getField($model->idField)->getPersistenceName()
        );
    } elseif ($platform instanceof OraclePlatform) {
        $sequenceName = $model->table . '_SEQ';
    } else {
        $sequenceName = null;
    }

    return $connection->lastInsertId($sequenceName);
}
718: }
719: